numba 向量化中的 ValueError for accumulate

ValueError in numba vectorize for accumulate

我正在尝试使用 Numba 编写一个 ufunc。我阅读 this 并将其合并到我的代码中。所以,我运行的基本代码是

import numpy as np
arr = np.arange(15).reshape((3,5))
def myadd(x, y):
  return x+y
myadd = np.frompyfunc(myadd, 2, 1)
print(myadd.accumulate(arr,dtype=object,axis=1).astype(int))

现在如果我使用 Numba 矢量化

import numpy as np
import numba as nb
arr = np.arange(15).reshape((3,5))
@nb.vectorize
def myadd(x, y):
  return x+y
print(myadd.accumulate(arr,dtype=object,axis=1).astype(int))

我收到错误

Traceback (most recent call last):

  File "<ipython-input-11-ea9f981e42b2>", line 7, in <module>
    print(myadd.accumulate(arr,dtype=object,axis=1).astype(int))

ValueError: could not find a matching type for myadd.accumulate, requested type has type code 'O'

解决这个问题的方法是什么?我在 Anaconda 的 Spyder 下使用 Numba 0.54.0,Numpy 1.20.3,Python 3.8.10 在 Windows 10.

我认为添加签名和删除 dtype 将解决错误。

对于为什么添加签名有效,我没有确切的答案。但希望这会有所帮助 运行 并帮助您找到答案。 (我的看法,arr是int类型,myadd要float....不确定)

import numpy as np
import numba as nb
arr = np.arange(15).reshape((3,5))
@nb.vectorize([(nb.int64, nb.int64)])
def myadd(x, y):
  return x+y
print(myadd.accumulate(arr, axis=1).astype(int))