使用子对象的类型初始化父对象 class 中的对象
Initializing an object in a parent class with the type of the child
我写了一个父 class,我在其中定义了一些函数(示例 __mul__、__truediv__ 等)所有子 class 都应该有.这些函数在执行后应保持子 class 的类型。
这里有一些代码来解释我的意思:
class Magnet():
def __init__(self, length, strength):
self.length = length
self.strength = strength
return
def __mul__(self, other):
if np.isscalar(other):
return Magnet(self.length, self.strength * other)
else:
return NotImplemented
class Quadrupole(Magnet):
def __init__(self, length, strength, name):
super().__init__(length, strength)
self.name = name
return
现在如果我这样做:
Quad1 = Quadrupole(2, 10, 'Q1')
Quad2 = Quad1 * 2
那么 Quad1 的类型为“__main__.Quadrupole”,Quad2 的类型为“__main__.Magnet”。
我想知道如何做到这一点,以便保留子类型而不是将其重铸为父类型。一种解决方案是在子 classes 中重新定义这些函数并更改
if np.isscalar(other):
return Magnet(self.length, self.strength * other)
至
if np.isscalar(other):
return Quadrupole(self.length, self.strength * other)
但进行继承的主要原因是不复制粘贴代码。也许像 super() 但向下或者 class 类型的一些占位符...
非常感谢您的帮助。
采用的解决方案
正在使用
return type(self)(self.length, self.strength * other)
很有魅力。它引发了一个错误,因为我忘记在 Magnet.__init__() 中添加 'name' 参数(我的原始代码确实如此,但是在简化示例时搞砸了)。
我在这里也发现了同样的问题:Returning object of same subclass in __add__ operator
解决方案 1
您可以使用 type(self)
获取类型并从中创建一个新对象。
def __mul__(self, other):
if np.isscalar(other):
return type(self)(self.length, self.strength * other)
raise NotImplemented
(也提出了 NotImplemented
而不是返回它。)
现在使用您的代码将导致:
return type(self)(self.length, self.strength * other)
TypeError: __init__() missing 1 required positional argument: 'name'
这需要 Quadrupole
为 name
设置默认参数。
class Quadrupole(Magnet):
def __init__(self, length, strength, name='unknown'):
super().__init__(length, strength)
self.name = name
你写代码很开心,但你可能不开心。原因是您现在丢失了有关 Quadrupole
class.
的 name
的信息
解决方案 2
您正在返回 class 的新实例,有时这不是必需的,您可以只改变旧的 class。这会将您的代码简化为:
def __mul__(self, other):
if np.isscalar(other):
self.strength *= other.strength
return self
raise NotImplemented
这会改变您的旧实例。
解决方案 3
解决方案 1 的主要问题是您丢失了信息,因为您正在创建一个新的 class。现在一个可能的替代方案是只复制 class。不幸的是,复制 class 并不总是那么简单。
基于 SO 问题,在这种情况下可能使用 deepcopy
,但如果你有一个复杂的 class 结构,你可能必须实现 __copy__
得到你想要的。
def __mul__(self, other):
if np.isscalar(other):
class_copy = deepcopy(self)
class_copy.strength *= other
return class_copy
raise NotImplemented
您可以选择提供 __copy__
方法。对于提供的代码片段,这不是必需的,但在更复杂的情况下可能是必需的。
def __copy__(self):
return Quadrupole(self.length, self.strength, self.name)
我写了一个父 class,我在其中定义了一些函数(示例 __mul__、__truediv__ 等)所有子 class 都应该有.这些函数在执行后应保持子 class 的类型。
这里有一些代码来解释我的意思:
class Magnet():
def __init__(self, length, strength):
self.length = length
self.strength = strength
return
def __mul__(self, other):
if np.isscalar(other):
return Magnet(self.length, self.strength * other)
else:
return NotImplemented
class Quadrupole(Magnet):
def __init__(self, length, strength, name):
super().__init__(length, strength)
self.name = name
return
现在如果我这样做:
Quad1 = Quadrupole(2, 10, 'Q1')
Quad2 = Quad1 * 2
那么 Quad1 的类型为“__main__.Quadrupole”,Quad2 的类型为“__main__.Magnet”。
我想知道如何做到这一点,以便保留子类型而不是将其重铸为父类型。一种解决方案是在子 classes 中重新定义这些函数并更改
if np.isscalar(other):
return Magnet(self.length, self.strength * other)
至
if np.isscalar(other):
return Quadrupole(self.length, self.strength * other)
但进行继承的主要原因是不复制粘贴代码。也许像 super() 但向下或者 class 类型的一些占位符...
非常感谢您的帮助。
采用的解决方案
正在使用
return type(self)(self.length, self.strength * other)
很有魅力。它引发了一个错误,因为我忘记在 Magnet.__init__() 中添加 'name' 参数(我的原始代码确实如此,但是在简化示例时搞砸了)。
我在这里也发现了同样的问题:Returning object of same subclass in __add__ operator
解决方案 1
您可以使用 type(self)
获取类型并从中创建一个新对象。
def __mul__(self, other):
if np.isscalar(other):
return type(self)(self.length, self.strength * other)
raise NotImplemented
(也提出了 NotImplemented
而不是返回它。)
现在使用您的代码将导致:
return type(self)(self.length, self.strength * other)
TypeError: __init__() missing 1 required positional argument: 'name'
这需要 Quadrupole
为 name
设置默认参数。
class Quadrupole(Magnet):
def __init__(self, length, strength, name='unknown'):
super().__init__(length, strength)
self.name = name
你写代码很开心,但你可能不开心。原因是您现在丢失了有关 Quadrupole
class.
name
的信息
解决方案 2
您正在返回 class 的新实例,有时这不是必需的,您可以只改变旧的 class。这会将您的代码简化为:
def __mul__(self, other):
if np.isscalar(other):
self.strength *= other.strength
return self
raise NotImplemented
这会改变您的旧实例。
解决方案 3
解决方案 1 的主要问题是您丢失了信息,因为您正在创建一个新的 class。现在一个可能的替代方案是只复制 class。不幸的是,复制 class 并不总是那么简单。
基于 deepcopy
,但如果你有一个复杂的 class 结构,你可能必须实现 __copy__
得到你想要的。
def __mul__(self, other):
if np.isscalar(other):
class_copy = deepcopy(self)
class_copy.strength *= other
return class_copy
raise NotImplemented
您可以选择提供 __copy__
方法。对于提供的代码片段,这不是必需的,但在更复杂的情况下可能是必需的。
def __copy__(self):
return Quadrupole(self.length, self.strength, self.name)