为什么此代码无法使用 Numba 进行编译?
Why does this code fail to compile with Numba?
我有一个示例代码可以说明我的问题。如果你 运行:
import numpy as np
from numba import jit
@jit(nopython=True)
def test():
arr = np.array([[[11, 12, 13], [11, 12, 13]], [[21, 22, 23], [21, 22, 23]]])
arr2 = arr[:, 0, :]
arr3 = arr2.argsort()
print(arr3)
test()
它会失败:
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of BoundFunction(array.argsort for array(int64, 2d, A)) with parameters ()
During: resolving callee type: BoundFunction(array.argsort for array(int64, 2d, A))
During: typing of call at /home/stark/Work/mmr6/test.py (41)
File "test.py", line 41:
def test():
<source elided>
arr3 = arr2.argsort()
^
argsort 应该在最后一个轴上进行 argsort。基本上它应该给我:
>>>
[[0 1 2]
[0 1 2]]
我认为复制 arr2
数组(使用 copy()
)可以解决问题,因为它会使数组在内存中连续(而不是视图),但它失败并显示相同的消息,除了消息中 arr2
的类型现在是预期的 array(int64, 2d, C)
。
为什么会失败,我该如何解决?
遗憾的是,这是 Numba 当前已知的限制。参见 this issue。目前仅支持一维数组。但是,您的情况有一个简单的解决方法:
import numpy as np
from numba import jit
@jit(nopython=True)
def test():
arr = np.array([[[11, 12, 13], [11, 12, 13]], [[21, 22, 23], [21, 22, 23]]])
arr2 = arr[:, 0, :]
arr3 = np.empty(arr2.shape, dtype=arr2.dtype)
for i in range(arr2.shape[0]):
arr3[i] = arr2[i, :].argsort()
print(arr3)
test()
请注意,即使实现了,也不会更快。参见 this issue。实际上,对于任何给定的 Numpy 基元,Numba 没有理由更快。但是,您可以使用 Numba 手动编写您自己的 Numpy 原语版本,并且有时会由于算法专业化、并行性或数学优化(例如快速数学)而获得加速。当您想执行 Numpy 中 yet/directly 不可用的高效操作时,Numba 通常很棒,并且可以使用循环轻松实现此操作。
实际上,您可以使用 Numba 的 prange
和 JIT 参数 parallel=True
来加快计算速度,假设 argsort
尚未 运行 并行(AFAIK 它应该是连续的)。这应该比大数组上的 Numpy 实现(也不应该 运行 顺序)快一点(在小数组上,产生多个线程的成本可能比实际计算更大)。
我有一个示例代码可以说明我的问题。如果你 运行:
import numpy as np
from numba import jit
@jit(nopython=True)
def test():
arr = np.array([[[11, 12, 13], [11, 12, 13]], [[21, 22, 23], [21, 22, 23]]])
arr2 = arr[:, 0, :]
arr3 = arr2.argsort()
print(arr3)
test()
它会失败:
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of BoundFunction(array.argsort for array(int64, 2d, A)) with parameters ()
During: resolving callee type: BoundFunction(array.argsort for array(int64, 2d, A))
During: typing of call at /home/stark/Work/mmr6/test.py (41)
File "test.py", line 41:
def test():
<source elided>
arr3 = arr2.argsort()
^
argsort 应该在最后一个轴上进行 argsort。基本上它应该给我:
>>>
[[0 1 2]
[0 1 2]]
我认为复制 arr2
数组(使用 copy()
)可以解决问题,因为它会使数组在内存中连续(而不是视图),但它失败并显示相同的消息,除了消息中 arr2
的类型现在是预期的 array(int64, 2d, C)
。
为什么会失败,我该如何解决?
遗憾的是,这是 Numba 当前已知的限制。参见 this issue。目前仅支持一维数组。但是,您的情况有一个简单的解决方法:
import numpy as np
from numba import jit
@jit(nopython=True)
def test():
arr = np.array([[[11, 12, 13], [11, 12, 13]], [[21, 22, 23], [21, 22, 23]]])
arr2 = arr[:, 0, :]
arr3 = np.empty(arr2.shape, dtype=arr2.dtype)
for i in range(arr2.shape[0]):
arr3[i] = arr2[i, :].argsort()
print(arr3)
test()
请注意,即使实现了,也不会更快。参见 this issue。实际上,对于任何给定的 Numpy 基元,Numba 没有理由更快。但是,您可以使用 Numba 手动编写您自己的 Numpy 原语版本,并且有时会由于算法专业化、并行性或数学优化(例如快速数学)而获得加速。当您想执行 Numpy 中 yet/directly 不可用的高效操作时,Numba 通常很棒,并且可以使用循环轻松实现此操作。
实际上,您可以使用 Numba 的 prange
和 JIT 参数 parallel=True
来加快计算速度,假设 argsort
尚未 运行 并行(AFAIK 它应该是连续的)。这应该比大数组上的 Numpy 实现(也不应该 运行 顺序)快一点(在小数组上,产生多个线程的成本可能比实际计算更大)。