使用子对象的类型初始化父对象 class 中的对象

Initializing an object in a parent class with the type of the child

我写了一个父 class,我在其中定义了一些函数(示例 __mul__、__truediv__ 等)所有子 class 都应该有.这些函数在执行后应保持子 class 的类型。

这里有一些代码来解释我的意思:

class Magnet():
    
    def __init__(self, length, strength):
        
        self.length = length
        self.strength = strength
        
        return
    
    def __mul__(self, other):

        if np.isscalar(other):
            return Magnet(self.length, self.strength * other)

        else:
            return NotImplemented
        
class Quadrupole(Magnet):
    
    def __init__(self, length, strength, name):
        
        super().__init__(length, strength)
        
        self.name = name
        
        return

现在如果我这样做:

Quad1 = Quadrupole(2, 10, 'Q1')
Quad2 = Quad1 * 2

那么 Quad1 的类型为“__main__.Quadrupole”,Quad2 的类型为“__main__.Magnet”。

我想知道如何做到这一点,以便保留子类型而不是将其重铸为父类型。一种解决方案是在子 classes 中重新定义这些函数并更改

        if np.isscalar(other):
            return Magnet(self.length, self.strength * other)

        if np.isscalar(other):
            return Quadrupole(self.length, self.strength * other)

但进行继承的主要原因是不复制粘贴代码。也许像 super() 但向下或者 class 类型的一些占位符...

非常感谢您的帮助。

采用的解决方案

正在使用

return type(self)(self.length, self.strength * other)

很有魅力。它引发了一个错误,因为我忘记在 Magnet.__init__() 中添加 'name' 参数(我的原始代码确实如此,但是在简化示例时搞砸了)。

我在这里也发现了同样的问题:Returning object of same subclass in __add__ operator

解决方案 1

您可以使用 type(self) 获取类型并从中创建一个新对象。

def __mul__(self, other):
    if np.isscalar(other):
        return type(self)(self.length, self.strength * other)
    raise NotImplemented

(也提出了 NotImplemented 而不是返回它。)

现在使用您的代码将导致:

    return type(self)(self.length, self.strength * other)
TypeError: __init__() missing 1 required positional argument: 'name'

这需要 Quadrupolename 设置默认参数。

class Quadrupole(Magnet):
    def __init__(self, length, strength, name='unknown'):
        super().__init__(length, strength)
        self.name = name

你写代码很开心,但你可能不开心。原因是您现在丢失了有关 Quadrupole class.

name 的信息

解决方案 2

您正在返回 class 的新实例,有时这不是必需的,您可以只改变旧的 class。这会将您的代码简化为:

def __mul__(self, other):
    if np.isscalar(other):
        self.strength *= other.strength
        return self
    raise NotImplemented

这会改变您的旧实例。

解决方案 3

解决方案 1 的主要问题是您丢失了信息,因为您正在创建一个新的 class。现在一个可能的替代方案是只复制 class。不幸的是,复制 class 并不总是那么简单。

基于 SO 问题,在这种情况下可能使用 deepcopy,但如果你有一个复杂的 class 结构,你可能必须实现 __copy__得到你想要的。

def __mul__(self, other):
    if np.isscalar(other):
        class_copy = deepcopy(self)
        class_copy.strength *= other
        return class_copy
    raise NotImplemented

您可以选择提供 __copy__ 方法。对于提供的代码片段,这不是必需的,但在更复杂的情况下可能是必需的。

def __copy__(self):
    return Quadrupole(self.length, self.strength, self.name)