如何比较持有 numpy.ndarray 的数据类的相等性(bool(a==b) 引发 ValueError)?
How to compare equality of dataclasses holding numpy.ndarray (bool(a==b) raises ValueError)?
如果我创建一个包含 Numpy ndarray 的 Python 数据类,我将无法再使用自动生成的 __eq__
。
import numpy as np
@dataclass
class Instr:
foo: np.ndarray
bar: np.ndarray
arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
这是因为 ndarray.__eq__
有时 returns 一个 ndarray
真值,通过比较 a[0]
和 b[0]
,依此类推,直到 2 中的较长者。这非常复杂且不直观,实际上只有当数组形状不同或具有不同值时才会引发错误。
我如何安全地比较持有 Numpy 数组的 @dataclass
es?
@dataclass
对 __eq__
的实现是使用 eval()
生成的。它的来源在堆栈跟踪中丢失,无法使用 inspect
查看,但它实际上使用 元组比较 ,它调用 bool(foo)。
import dis
dis.dis(Instr.__eq__)
摘录:
3 12 LOAD_FAST 0 (self)
14 LOAD_ATTR 1 (foo)
16 LOAD_FAST 0 (self)
18 LOAD_ATTR 2 (bar)
20 BUILD_TUPLE 2
22 LOAD_FAST 1 (other)
24 LOAD_ATTR 1 (foo)
26 LOAD_FAST 1 (other)
28 LOAD_ATTR 2 (bar)
30 BUILD_TUPLE 2
32 COMPARE_OP 2 (==)
34 RETURN_VALUE
解决方案是放入你自己的 __eq__
方法并设置 eq=False
这样数据类就不会生成它自己的(尽管检查 docs 最后一步不是有必要,但我认为无论如何都要明确。
import numpy as np
def array_eq(arr1, arr2):
return (isinstance(arr1, np.ndarray) and
isinstance(arr2, np.ndarray) and
arr1.shape == arr2.shape and
(arr1 == arr2).all())
@dataclass(eq=False)
class Instr:
foo: np.ndarray
bar: np.ndarray
def __eq__(self, other):
if not isinstance(other, Instr):
return NotImplemented
return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)
编辑
通用数据类的通用快速解决方案,其中一些值是 numpy 数组而另一些不是
import numpy as np
from dataclasses import dataclass, astuple
def array_safe_eq(a, b) -> bool:
"""Check if a and b are equal, even if they are numpy arrays"""
if a is b:
return True
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
return a.shape == b.shape and (a == b).all()
try:
return a == b
except TypeError:
return NotImplemented
def dc_eq(dc1, dc2) -> bool:
"""checks if two dataclasses which hold numpy arrays are equal"""
if dc1 is dc2:
return True
if dc1.__class__ is not dc2.__class__:
return NotImplmeneted # better than False
t1 = astuple(dc1)
t2 = astuple(dc2)
return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))
# usage
@dataclass(eq=False)
class T:
a: int
b: np.ndarray
c: np.ndarray
def __eq__(self, other):
return dc_eq(self, other)
如果我创建一个包含 Numpy ndarray 的 Python 数据类,我将无法再使用自动生成的 __eq__
。
import numpy as np
@dataclass
class Instr:
foo: np.ndarray
bar: np.ndarray
arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
这是因为 ndarray.__eq__
有时 returns 一个 ndarray
真值,通过比较 a[0]
和 b[0]
,依此类推,直到 2 中的较长者。这非常复杂且不直观,实际上只有当数组形状不同或具有不同值时才会引发错误。
我如何安全地比较持有 Numpy 数组的 @dataclass
es?
@dataclass
对 __eq__
的实现是使用 eval()
生成的。它的来源在堆栈跟踪中丢失,无法使用 inspect
查看,但它实际上使用 元组比较 ,它调用 bool(foo)。
import dis
dis.dis(Instr.__eq__)
摘录:
3 12 LOAD_FAST 0 (self) 14 LOAD_ATTR 1 (foo) 16 LOAD_FAST 0 (self) 18 LOAD_ATTR 2 (bar) 20 BUILD_TUPLE 2 22 LOAD_FAST 1 (other) 24 LOAD_ATTR 1 (foo) 26 LOAD_FAST 1 (other) 28 LOAD_ATTR 2 (bar) 30 BUILD_TUPLE 2 32 COMPARE_OP 2 (==) 34 RETURN_VALUE
解决方案是放入你自己的 __eq__
方法并设置 eq=False
这样数据类就不会生成它自己的(尽管检查 docs 最后一步不是有必要,但我认为无论如何都要明确。
import numpy as np
def array_eq(arr1, arr2):
return (isinstance(arr1, np.ndarray) and
isinstance(arr2, np.ndarray) and
arr1.shape == arr2.shape and
(arr1 == arr2).all())
@dataclass(eq=False)
class Instr:
foo: np.ndarray
bar: np.ndarray
def __eq__(self, other):
if not isinstance(other, Instr):
return NotImplemented
return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)
编辑
通用数据类的通用快速解决方案,其中一些值是 numpy 数组而另一些不是
import numpy as np
from dataclasses import dataclass, astuple
def array_safe_eq(a, b) -> bool:
"""Check if a and b are equal, even if they are numpy arrays"""
if a is b:
return True
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
return a.shape == b.shape and (a == b).all()
try:
return a == b
except TypeError:
return NotImplemented
def dc_eq(dc1, dc2) -> bool:
"""checks if two dataclasses which hold numpy arrays are equal"""
if dc1 is dc2:
return True
if dc1.__class__ is not dc2.__class__:
return NotImplmeneted # better than False
t1 = astuple(dc1)
t2 = astuple(dc2)
return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))
# usage
@dataclass(eq=False)
class T:
a: int
b: np.ndarray
c: np.ndarray
def __eq__(self, other):
return dc_eq(self, other)