numba:函数重载

numba: overload of function

我正在试用 numba,据说 python 包可以让我的 nparray 超级快。我想在非 python 模式下 运行 我的函数。它本质上做的是接受一个 20x20 数组,为它的每个元素分配随机数,计算它的逆矩阵,然后 return 它。 但问题是,当我用 np.zeros() 初始化数组 result 时,我的脚本崩溃并给我一条错误消息 'overload of function zeros'。 有人可以告诉我发生了什么事吗?非常感谢。

from numba import njit
import time
import numpy as np
import random

arr = np.zeros((20,20),dtype = float)
@njit
def aFunctionWithNumba (incomingArray):
    result = np.zeros(np.shape(incomingArray), dtype = float)
    for i in range(len(incomingArray[0])):
        for j in range(len(incomingArray[1])):
            incomingArray[i,j] = random.randrange(105150,1541586)
    result = np.linalg.inv(incomingArray)
    return result

t0 = time.time()
fastArray = aFunctionWithNumba(arr)
t1 = time.time()
s1 = t1 - t0

完整的错误信息如下:

Exception has occurred: TypingError Failed in nopython mode pipeline (step: nopython frontend) No implementation of function Function(<built-in function zeros>) found for signature:
 
 >>> zeros(UniTuple(int64 x 2), dtype=Function(<class 'float'>))   There are 2 candidate implementations:
  - Of which 2 did not match due to:   Overload of function 'zeros': File: numba\core\typing\npydecl.py: Line 511.
    With argument(s): '(UniTuple(int64 x 2), dtype=Function(<class 'float'>))':    No match.

During: resolving callee type: Function(<built-in function zeros>) During: typing of call at c:\Users\Eric\Desktop\testNumba.py (9)


File "testNumba.py", line 9: def aFunctionWithNumba (incomingArray):
    result = np.zeros(np.shape(incomingArray), dtype = float)
    ^   File "C:\Users\Eric\Desktop\testNumba.py", line 25, in <module>
    fastArray = aFunctionWithNumba(arr)

错误

您应该在 JITted 函数中使用 Numpy 或 Numba 类型。

更改以下行您的代码有效:

result = np.zeros(np.shape(incomingArray), dtype=np.float64)

但是您的代码将更通用:

result = np.zeros(incomingArray.shape, dtype=incomingArray.dtype)

或者,甚至更好:

result = np.zeros_like(incomingArray)

时机

第一次调用 JITted 函数时,编译它需要一些时间,比执行它所花费的时间要长得多。所以你应该在做任何计时之前用相同的参数类型调用一次函数。

进一步优化

如果您有兴趣比较使用或不使用 Numba 的嵌套循环的执行时间,您的代码很好。否则你可以用类似的东西替换循环:

incomingArray[:] = np.random.random(incomingArray.shape) * (1541586 - 105150) + 105150