numba.jit 无法编译 np.roll
numba.jit can’t compile np.roll
我正在尝试使用 jit
编译 "foo" 函数
import numpy as np
from numba import jit
dy = 5
@jit
def foo(grid):
return np.sum([np.roll(np.roll(grid, y, axis = 1), x, axis = 0)
for x in (-1, 0, 1) for y in (-1, 0, 1) if x or y], axis=0)
ex_grid = np.random.rand(5,5)>0.5
result = foo(ex_grid)
我收到以下错误:
Compilation is falling back to object mode WITH looplifting enabled because Function "foo" failed type inference due to: Invalid use of Function(<function roll at 0x00000161E45C7D90>) with argument(s) of type(s): (array(bool, 2d, C), Literal[int](5), axis=Literal[int](1))
* parameterized
In definition 0:
TypeError: np_roll() got an unexpected keyword argument 'axis'
功能正常,编译失败
如何修复此错误,np.roll 是否与 numba 兼容,如果不兼容,是否有替代方案?
如果你检查 docs 你会看到 np.roll
只支持前两个参数,因此它只会在展平数组上执行滚动(因为你不能指定一个轴)。
numpy.roll() (only the 2 first arguments; second argument shift must be an integer)
但是请注意,在这里使用 numba 并没有真正意义,因为您正在执行单个矢量化操作,这已经 运行 非常快了。只有当您必须遍历数组以应用某些逻辑时,Numba 才有意义。
因此,使用 numba roll
数组行的唯一可能方法是循环遍历它们:
@njit
def foo(a, dy):
out = np.empty(a.shape, np.int32)
for i in range(a.shape[0]):
out[i] = np.roll(a[i], dy)
return out
np.allclose(foo(ex_grid, 3).astype(bool), np.roll(ex_grid, 3, axis=1))
# True
尽管如前所述,这将比简单地使用 np.roll
设置 axis=1
慢得多,因为这已经矢量化并且所有循环都在 C
级别上完成:
ex_grid = np.random.rand(5000,5000)>0.5
%timeit foo(ex_grid, 3)
# 111 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.roll(ex_grid, 1, axis=1)
# 13.8 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
我正在尝试使用 jit
编译 "foo" 函数import numpy as np
from numba import jit
dy = 5
@jit
def foo(grid):
return np.sum([np.roll(np.roll(grid, y, axis = 1), x, axis = 0)
for x in (-1, 0, 1) for y in (-1, 0, 1) if x or y], axis=0)
ex_grid = np.random.rand(5,5)>0.5
result = foo(ex_grid)
我收到以下错误:
Compilation is falling back to object mode WITH looplifting enabled because Function "foo" failed type inference due to: Invalid use of Function(<function roll at 0x00000161E45C7D90>) with argument(s) of type(s): (array(bool, 2d, C), Literal[int](5), axis=Literal[int](1))
* parameterized
In definition 0:
TypeError: np_roll() got an unexpected keyword argument 'axis'
功能正常,编译失败
如何修复此错误,np.roll 是否与 numba 兼容,如果不兼容,是否有替代方案?
如果你检查 docs 你会看到 np.roll
只支持前两个参数,因此它只会在展平数组上执行滚动(因为你不能指定一个轴)。
numpy.roll() (only the 2 first arguments; second argument shift must be an integer)
但是请注意,在这里使用 numba 并没有真正意义,因为您正在执行单个矢量化操作,这已经 运行 非常快了。只有当您必须遍历数组以应用某些逻辑时,Numba 才有意义。
因此,使用 numba roll
数组行的唯一可能方法是循环遍历它们:
@njit
def foo(a, dy):
out = np.empty(a.shape, np.int32)
for i in range(a.shape[0]):
out[i] = np.roll(a[i], dy)
return out
np.allclose(foo(ex_grid, 3).astype(bool), np.roll(ex_grid, 3, axis=1))
# True
尽管如前所述,这将比简单地使用 np.roll
设置 axis=1
慢得多,因为这已经矢量化并且所有循环都在 C
级别上完成:
ex_grid = np.random.rand(5000,5000)>0.5
%timeit foo(ex_grid, 3)
# 111 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.roll(ex_grid, 1, axis=1)
# 13.8 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)