带有字符串 dtype 的 Numpy 鸭子数组意外抛出“numpy.core._exceptions._UFuncNoLoopError”

Numpy duck array with string dtype unexpectedly throws `numpy.core._exceptions._UFuncNoLoopError`

这是我一直用于数字数据的简单 numpy 鸭子数组的最小工作示例。

import numpy as np

class DuckArray(np.lib.mixins.NDArrayOperatorsMixin):

    def __init__(self, array: np.ndarray):
        self.array = array

    def __repr__(self):
        return f'DuckArray({self.array})'

    def __array_ufunc__(self, function, method, *inputs, **kwargs):

        # Normalize inputs
        inputs = [inp.array if isinstance(inp, type(self)) else inp for inp in inputs]

        # Loop through inputs until we find a valid implementation
        for inp in inputs:
            result = inp.__array_ufunc__(function, method, *inputs, **kwargs)
            if result is not NotImplemented:
                return type(self)(result)

            return NotImplemented

这个 class 的真实版本有一个 __array_function__ 的实现 也一样,不过这题只涉及__array_ufunc__.

如我们所见,此实现适用于数字数据类型。

In [1]: a = DuckArray(np.array([1, 2, 3]))
In [2]: a + 2
Out[2]: DuckArray([3 4 5])
In [3]: a == 2
Out[3]: DuckArray([False  True False])

但如果数组是字符串 dtype

,它会失败并显示 numpy.core._exceptions._UFuncNoLoopError
In [4]: b = DuckArray(np.array(['abc', 'def', 'ghi']))
In [5]: b == 'def'
Traceback (most recent call last):
  File "C:\Users\byrdie\AppData\Local\Programs\Python\Python38\lib\site-packages\IPython\core\interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-c5975227701e>", line 1, in <module>
    b == 'def'
  File "C:\Users\byrdie\AppData\Local\Programs\Python\Python38\lib\site-packages\numpy\lib\mixins.py", line 21, in func
    return ufunc(self, other)
  File "<ipython-input-2-aced4bbdd318>", line 15, in __array_ufunc__
    result = inp.__array_ufunc__(function, method, *inputs, **kwargs)
numpy.core._exceptions._UFuncNoLoopError: ufunc 'equal' did not contain a loop with signature matching types (dtype('<U3'), dtype('<U3')) -> dtype('bool')

即使相同的操作显然适用于原始数组。

In [6]: b.array == 'def'
Out[6]: array([False,  True, False])

这告诉我 ufunc 循环确实存在,但显然有些地方出错了。

有人知道我哪里错了吗?

当你创建一个 numpy 字符串数组时,每个字符串的 dtype 默认为 <Un 其中 n 是它的长度

np.array(['abc', 'defg'])[0].dtype
>> dtype('<U3')
np.array(['abc', 'defg'])[1].dtype
>> dtype('<U4')

np.equal ufunc 不支持比较 <Un dtype,因此使用它比较 'abc''def' 中的两个 <U3 会出错。

要修复它,请在创建字符串数组时将 dtype 显式声明为 object

DuckArray(np.array(['abc', 'def'], dtype=object)) == 'abc'
>> DuckArray([ True False])