带有字符串 dtype 的 Numpy 鸭子数组意外抛出“numpy.core._exceptions._UFuncNoLoopError”
Numpy duck array with string dtype unexpectedly throws `numpy.core._exceptions._UFuncNoLoopError`
这是我一直用于数字数据的简单 numpy 鸭子数组的最小工作示例。
import numpy as np
class DuckArray(np.lib.mixins.NDArrayOperatorsMixin):
def __init__(self, array: np.ndarray):
self.array = array
def __repr__(self):
return f'DuckArray({self.array})'
def __array_ufunc__(self, function, method, *inputs, **kwargs):
# Normalize inputs
inputs = [inp.array if isinstance(inp, type(self)) else inp for inp in inputs]
# Loop through inputs until we find a valid implementation
for inp in inputs:
result = inp.__array_ufunc__(function, method, *inputs, **kwargs)
if result is not NotImplemented:
return type(self)(result)
return NotImplemented
这个 class 的真实版本有一个 __array_function__
的实现
也一样,不过这题只涉及__array_ufunc__
.
如我们所见,此实现适用于数字数据类型。
In [1]: a = DuckArray(np.array([1, 2, 3]))
In [2]: a + 2
Out[2]: DuckArray([3 4 5])
In [3]: a == 2
Out[3]: DuckArray([False True False])
但如果数组是字符串 dtype
,它会失败并显示 numpy.core._exceptions._UFuncNoLoopError
In [4]: b = DuckArray(np.array(['abc', 'def', 'ghi']))
In [5]: b == 'def'
Traceback (most recent call last):
File "C:\Users\byrdie\AppData\Local\Programs\Python\Python38\lib\site-packages\IPython\core\interactiveshell.py", line 3441, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-6-c5975227701e>", line 1, in <module>
b == 'def'
File "C:\Users\byrdie\AppData\Local\Programs\Python\Python38\lib\site-packages\numpy\lib\mixins.py", line 21, in func
return ufunc(self, other)
File "<ipython-input-2-aced4bbdd318>", line 15, in __array_ufunc__
result = inp.__array_ufunc__(function, method, *inputs, **kwargs)
numpy.core._exceptions._UFuncNoLoopError: ufunc 'equal' did not contain a loop with signature matching types (dtype('<U3'), dtype('<U3')) -> dtype('bool')
即使相同的操作显然适用于原始数组。
In [6]: b.array == 'def'
Out[6]: array([False, True, False])
这告诉我 ufunc 循环确实存在,但显然有些地方出错了。
有人知道我哪里错了吗?
当你创建一个 numpy 字符串数组时,每个字符串的 dtype
默认为 <Un
其中 n
是它的长度
np.array(['abc', 'defg'])[0].dtype
>> dtype('<U3')
np.array(['abc', 'defg'])[1].dtype
>> dtype('<U4')
np.equal
ufunc 不支持比较 <Un
dtype,因此使用它比较 'abc'
和 'def'
中的两个 <U3
会出错。
要修复它,请在创建字符串数组时将 dtype
显式声明为 object
。
DuckArray(np.array(['abc', 'def'], dtype=object)) == 'abc'
>> DuckArray([ True False])
这是我一直用于数字数据的简单 numpy 鸭子数组的最小工作示例。
import numpy as np
class DuckArray(np.lib.mixins.NDArrayOperatorsMixin):
def __init__(self, array: np.ndarray):
self.array = array
def __repr__(self):
return f'DuckArray({self.array})'
def __array_ufunc__(self, function, method, *inputs, **kwargs):
# Normalize inputs
inputs = [inp.array if isinstance(inp, type(self)) else inp for inp in inputs]
# Loop through inputs until we find a valid implementation
for inp in inputs:
result = inp.__array_ufunc__(function, method, *inputs, **kwargs)
if result is not NotImplemented:
return type(self)(result)
return NotImplemented
这个 class 的真实版本有一个 __array_function__
的实现
也一样,不过这题只涉及__array_ufunc__
.
如我们所见,此实现适用于数字数据类型。
In [1]: a = DuckArray(np.array([1, 2, 3]))
In [2]: a + 2
Out[2]: DuckArray([3 4 5])
In [3]: a == 2
Out[3]: DuckArray([False True False])
但如果数组是字符串 dtype
,它会失败并显示numpy.core._exceptions._UFuncNoLoopError
In [4]: b = DuckArray(np.array(['abc', 'def', 'ghi']))
In [5]: b == 'def'
Traceback (most recent call last):
File "C:\Users\byrdie\AppData\Local\Programs\Python\Python38\lib\site-packages\IPython\core\interactiveshell.py", line 3441, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-6-c5975227701e>", line 1, in <module>
b == 'def'
File "C:\Users\byrdie\AppData\Local\Programs\Python\Python38\lib\site-packages\numpy\lib\mixins.py", line 21, in func
return ufunc(self, other)
File "<ipython-input-2-aced4bbdd318>", line 15, in __array_ufunc__
result = inp.__array_ufunc__(function, method, *inputs, **kwargs)
numpy.core._exceptions._UFuncNoLoopError: ufunc 'equal' did not contain a loop with signature matching types (dtype('<U3'), dtype('<U3')) -> dtype('bool')
即使相同的操作显然适用于原始数组。
In [6]: b.array == 'def'
Out[6]: array([False, True, False])
这告诉我 ufunc 循环确实存在,但显然有些地方出错了。
有人知道我哪里错了吗?
当你创建一个 numpy 字符串数组时,每个字符串的 dtype
默认为 <Un
其中 n
是它的长度
np.array(['abc', 'defg'])[0].dtype
>> dtype('<U3')
np.array(['abc', 'defg'])[1].dtype
>> dtype('<U4')
np.equal
ufunc 不支持比较 <Un
dtype,因此使用它比较 'abc'
和 'def'
中的两个 <U3
会出错。
要修复它,请在创建字符串数组时将 dtype
显式声明为 object
。
DuckArray(np.array(['abc', 'def'], dtype=object)) == 'abc'
>> DuckArray([ True False])