如何简化重复的算术函数?
How can I simplify my repetitive arithmetic functions?
我有一个 class 包含一些复杂向量公式的计算。详细的结构可能并不重要;它基本上是在网格上计算提供的矢量公式,然后 stores/manages 结果。
如果它们具有相同的时空坐标,我希望能够使用这些 classes 进行基本算术运算。我的问题是每个算术都需要大量的类型检查。这导致 5 个副本(每个 +,-,*,/,**
个)在 3 个位置具有完全相同的代码条和算术符号。
在我看来,面向对象语言中如此多的重复代码看起来很可疑。同时,我也想不出一个优雅的解决方案来简化它,我觉得有一些我不知道的方法。
如何以最佳实践方式提取重复代码?
代码如下,我标出了__sub__
与__add__
的3处不同:
class FieldVector(object):
def __init__(self, formula, fieldparams, meshparams, zero_comps=[]):
[...]
def is_comparable_to(self, other):
"has the same spatio-temporal dimensions as the other one"
if not other.isinstance(FieldVector):
return False
return (
self.meshparams == other.meshparams and
self.framenum == other.framenum and
self.fieldparams.tnull == other.fieldparams.tnull and
self.fieldparams.tmax == other.fieldparams.tmax
)
def _check_comparable(self, other):
if not self.is_comparable_to:
raise ValueError("The two fields have different spatio-temporal coordinates")
def __add__(self, other):
new_compvals = {}
if isinstance(other, Number):
for comp in self.nonzero_comps:
new_compvals[comp] = self.get_component(comp) + other
elif isinstance(other,FieldVector):
self._check_comparable(other)
nonzeros = list(set(self.nonzero_comps).union(other.nonzero_comps))
for comp in nonzeros:
new_compvals[comp] = self.get_component(comp) + other.get_component(comp)
else:
raise TypeError(f'unsupported operand type(s) for +: {self.__class__} and {other.__class__}')
return ModifiedFieldVector(self, new_compvals)
def __sub__(self, other):
new_compvals = {}
if isinstance(other, Number):
for comp in self.nonzero_comps:
# --- difference 1: - instead of +
new_compvals[comp] = self.get_component(comp) - other
elif isinstance(other,FieldVector):
self._check_comparable(other)
nonzeros = list(set(self.nonzero_comps).union(other.nonzero_comps))
for comp in nonzeros:
# --- difference 2: - instead of +
new_compvals[comp] = self.get_component(comp) - other.get_component(comp)
else:
# --- difference 3: - instead of +
raise TypeError(f'unsupported operand type(s) for -: {self.__class__} and {other.__class__}')
return ModifiedFieldVector(self, new_compvals)
[... __mul__, __truediv__, __pow__ defined the same way]
您可以将计算提取到私有方法中,然后将 operator 传递给它。
也许是这样的:
import operator
...
def _xeq(self, other, op):
new_compvals = {}
if isinstance(other, Number):
for comp in self.nonzero_comps:
new_compvals[comp] = op(self.get_component(comp), other)
elif isinstance(other,FieldVector):
self._check_comparable(other)
nonzeros = list(set(self.nonzero_comps).union(other.nonzero_comps))
for comp in nonzeros:
new_compvals[comp] = op(self.get_component(comp), other.get_component(comp))
else:
raise TypeError(f'unsupported operand type(s) for {op}: {self.__class__} and {other.__class__}')
return ModifiedFieldVector(self, new_compvals)
def __add__(self, other):
return self._xeq(other, operator.add)
def __sub__(self, other):
return self._xeq(other, operator.sub)
我有一个 class 包含一些复杂向量公式的计算。详细的结构可能并不重要;它基本上是在网格上计算提供的矢量公式,然后 stores/manages 结果。
如果它们具有相同的时空坐标,我希望能够使用这些 classes 进行基本算术运算。我的问题是每个算术都需要大量的类型检查。这导致 5 个副本(每个 +,-,*,/,**
个)在 3 个位置具有完全相同的代码条和算术符号。
在我看来,面向对象语言中如此多的重复代码看起来很可疑。同时,我也想不出一个优雅的解决方案来简化它,我觉得有一些我不知道的方法。
如何以最佳实践方式提取重复代码?
代码如下,我标出了__sub__
与__add__
的3处不同:
class FieldVector(object):
def __init__(self, formula, fieldparams, meshparams, zero_comps=[]):
[...]
def is_comparable_to(self, other):
"has the same spatio-temporal dimensions as the other one"
if not other.isinstance(FieldVector):
return False
return (
self.meshparams == other.meshparams and
self.framenum == other.framenum and
self.fieldparams.tnull == other.fieldparams.tnull and
self.fieldparams.tmax == other.fieldparams.tmax
)
def _check_comparable(self, other):
if not self.is_comparable_to:
raise ValueError("The two fields have different spatio-temporal coordinates")
def __add__(self, other):
new_compvals = {}
if isinstance(other, Number):
for comp in self.nonzero_comps:
new_compvals[comp] = self.get_component(comp) + other
elif isinstance(other,FieldVector):
self._check_comparable(other)
nonzeros = list(set(self.nonzero_comps).union(other.nonzero_comps))
for comp in nonzeros:
new_compvals[comp] = self.get_component(comp) + other.get_component(comp)
else:
raise TypeError(f'unsupported operand type(s) for +: {self.__class__} and {other.__class__}')
return ModifiedFieldVector(self, new_compvals)
def __sub__(self, other):
new_compvals = {}
if isinstance(other, Number):
for comp in self.nonzero_comps:
# --- difference 1: - instead of +
new_compvals[comp] = self.get_component(comp) - other
elif isinstance(other,FieldVector):
self._check_comparable(other)
nonzeros = list(set(self.nonzero_comps).union(other.nonzero_comps))
for comp in nonzeros:
# --- difference 2: - instead of +
new_compvals[comp] = self.get_component(comp) - other.get_component(comp)
else:
# --- difference 3: - instead of +
raise TypeError(f'unsupported operand type(s) for -: {self.__class__} and {other.__class__}')
return ModifiedFieldVector(self, new_compvals)
[... __mul__, __truediv__, __pow__ defined the same way]
您可以将计算提取到私有方法中,然后将 operator 传递给它。
也许是这样的:
import operator
...
def _xeq(self, other, op):
new_compvals = {}
if isinstance(other, Number):
for comp in self.nonzero_comps:
new_compvals[comp] = op(self.get_component(comp), other)
elif isinstance(other,FieldVector):
self._check_comparable(other)
nonzeros = list(set(self.nonzero_comps).union(other.nonzero_comps))
for comp in nonzeros:
new_compvals[comp] = op(self.get_component(comp), other.get_component(comp))
else:
raise TypeError(f'unsupported operand type(s) for {op}: {self.__class__} and {other.__class__}')
return ModifiedFieldVector(self, new_compvals)
def __add__(self, other):
return self._xeq(other, operator.add)
def __sub__(self, other):
return self._xeq(other, operator.sub)