numba.jit 无法编译 np.roll

numba.jit can’t compile np.roll

我正在尝试使用 jit

编译 "foo" 函数
import numpy as np
from numba import jit

dy = 5
@jit
def foo(grid):
    return np.sum([np.roll(np.roll(grid, y, axis = 1), x, axis = 0)
                   for x in (-1, 0, 1) for y in (-1, 0, 1) if x or y], axis=0)


ex_grid = np.random.rand(5,5)>0.5
result = foo(ex_grid)

我收到以下错误:

Compilation is falling back to object mode WITH looplifting enabled because Function "foo" failed type inference due to: Invalid use of Function(<function roll at 0x00000161E45C7D90>) with argument(s) of type(s): (array(bool, 2d, C), Literal[int](5), axis=Literal[int](1))
 * parameterized
In definition 0:
    TypeError: np_roll() got an unexpected keyword argument 'axis'

功能正常,编译失败

如何修复此错误,np.roll 是否与 numba 兼容,如果不兼容,是否有替代方案?

如果你检查 docs 你会看到 np.roll 只支持前两个参数,因此它只会在展平数组上执行滚动(因为你不能指定一个轴)。

numpy.roll() (only the 2 first arguments; second argument shift must be an integer)

但是请注意,在这里使用 numba 并没有真正意义,因为您正在执行单个矢量化操作,这已经 运行 非常快了。只有当您必须遍历数组以应用某些逻辑时,Numba 才有意义。

因此,使用 numba roll 数组行的唯一可能方法是循环遍历它们:

@njit
def foo(a, dy):
    out = np.empty(a.shape, np.int32)
    for i in range(a.shape[0]):
        out[i] = np.roll(a[i], dy)
    return out

np.allclose(foo(ex_grid, 3).astype(bool), np.roll(ex_grid, 3, axis=1))
# True

尽管如前所述,这将比简单地使用 np.roll 设置 axis=1 慢得多,因为这已经矢量化并且所有循环都在 C 级别上完成:

ex_grid = np.random.rand(5000,5000)>0.5

%timeit foo(ex_grid, 3)
# 111 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit np.roll(ex_grid, 1, axis=1)
# 13.8 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)