Advanced indexing alternative for njit numba function
Given the following minimal reproducible example:
import numpy as np
from numba import jit

# variable number of dimensions
n_t = 8
# q is just a partition of n
q_ddl = 2
n_ddl = 3

np.random.seed(42)
df = np.random.rand(q_ddl*n_t,q_ddl*n_t)

# index array
# ddl_nl is a set of np.arange(n_ddl), ex: [0,1] ; [0,2] or even [0] ...
ddl_nl = np.array([0,1])
ij = np.asarray(np.meshgrid(ddl_nl,ddl_nl,indexing='ij'))

@jit(nopython=True)
def foo(df,ij):
    out = np.zeros((n_t,n_ddl,n_ddl))
    for i in range(0,n_t):
        d_i = np.zeros((n_ddl,n_ddl))
        # (q_ddl,q_ddl) non zero values into (n_ddl,n_ddl) shape
        d_i[ij[0], ij[1]] = df[i::n_t,i::n_t]
        # to check possible solutions
        out[i,...] = d_i
    return out

out_foo = foo(df,ij)
The function foo runs fine with @jit(nopython=True) disabled, but raises the following error when it is enabled:
TypeError: unsupported array index type array(int64, 2d, C) in UniTuple(array(int64, 2d, C) x 2)
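It occurs during the broadcasting assignment d_i[ij[0], ij[1]] = df[i::n_t,i::n_t]. I then tried flattening the 2D index array ij with something like d_i[ij[0].ravel(), ij[1].ravel()] = df[i::n_t,i::n_t].ravel(), which gives me the same output, but now raises a different error: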
NotImplementedError: only one advanced index supported
So I ended up working around it with a classic pair of nested for loops:
tmp = df[i::n_t,i::n_t]
for k,r in enumerate(ddl_nl):
    for l,c in enumerate(ddl_nl):
        d_i[r,c] = tmp[k,l]
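This works with the decorator enabled and gives the expected result.
But I keep wondering whether there is a numba-compatible alternative to this NumPy 2D array broadcasting operation that I am missing here. Any help would be appreciated.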
Checking some of your values:
In [446]: ddl_nl = np.array([0,1])
     ...: ij = np.asarray(np.meshgrid(ddl_nl,ddl_nl,indexing='ij'))
     ...:
In [447]: ij
Out[447]:
array([[[0, 0],
        [1, 1]],

       [[0, 1],
        [0, 1]]])
In [448]: n_t = 8
     ...: q_ddl = 2
     ...: n_ddl = 3
In [449]: d_i = np.zeros((n_ddl,n_ddl))
In [450]: d_i
Out[450]:
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])
In [451]: d_i[ij[0], ij[1]]
Out[451]:
array([[0., 0.],
       [0., 0.]])
Trying a more diagnostic d_i:
In [452]: d_i = np.arange(9).reshape(3,3)
In [453]: d_i[ij[0], ij[1]]
Out[453]:
array([[0, 1],
       [3, 4]])
In [454]: d_i[:2,:2]
Out[454]:
array([[0, 1],
       [3, 4]])
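Why use advanced indexing when basic slicing works?
I haven't tried this with numba, but it may have a better chance of working. That said, the enumerate loop may be just as fast. I don't have enough experience with numba to say for sure.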
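To make the slicing suggestion concrete, here is a minimal sketch. It assumes ddl_nl is the contiguous range 0..q-1, as in the question's [0,1] case, so it would not handle a selection such as [0,2] without changes. The function name foo_slice and the parameter q are illustrative and not from the question. The try/except fallback is only there so the snippet also runs where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback (assumption for portability): run as plain Python without numba.
    def njit(func):
        return func


@njit
def foo_slice(df, n_t, n_ddl, q):
    # Sizes are passed as arguments instead of being read from globals.
    out = np.zeros((n_t, n_ddl, n_ddl))
    for i in range(n_t):
        # Basic slicing on both sides: no advanced (array) index is involved.
        out[i, :q, :q] = df[i::n_t, i::n_t]
    return out


# Usage: compare against the NumPy advanced-indexing version from the question.
n_t, q, n_ddl = 8, 2, 3
np.random.seed(42)
df = np.random.rand(q * n_t, q * n_t)
out_slice = foo_slice(df, n_t, n_ddl, q)

ddl_nl = np.arange(q)
ij = np.asarray(np.meshgrid(ddl_nl, ddl_nl, indexing='ij'))
expected = np.zeros((n_t, n_ddl, n_ddl))
for i in range(n_t):
    d_i = np.zeros((n_ddl, n_ddl))
    d_i[ij[0], ij[1]] = df[i::n_t, i::n_t]
    expected[i] = d_i

print(np.allclose(out_slice, expected))
```

Whether this is faster than the explicit loops would need to be timed on the actual data sizes.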
===
Evidently you have hit a numpy operation that numba does not support:
In [456]: numba.__version__
Out[456]: '0.43.0'
In [457]: @numba.jit
     ...: def foo(arr):
     ...:     return arr[[1,2,3],[1,2,3]]
     ...:
In [458]: foo(np.eye(4))
Out[458]: array([1., 1., 1.])
In [459]: @numba.njit
     ...: def foo(arr):
     ...:     return arr[[1,2,3],[1,2,3]]
     ...:
In [460]: foo(np.eye(4))
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, C), tuple(list(int64) x 2))
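This is not unusual. numba does not claim to cover all of Python or numpy.
But with numba we don't have to avoid iteration. In fact, it is at its best when replacing operations that numpy cannot do without iteration.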
In [465]: @numba.njit
     ...: def foo(arr):
     ...:     out = np.zeros((3,), arr.dtype)
     ...:     for n, (i,j) in enumerate(zip([1,2,3],[1,2,3])):
     ...:         out[n] = arr[i,j]
     ...:     return out
In [466]: foo(np.eye(4))
Out[466]: array([1., 1., 1.])
In [467]: timeit foo(np.eye(4))
6.85 µs ± 28.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [468]: np.eye(4)[[1,2,3],[1,2,3]]
Out[468]: array([1., 1., 1.])
In [469]: timeit np.eye(4)[[1,2,3],[1,2,3]]
13.3 µs ± 31.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
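Avoid fancy indexing
Also avoid global variables (they are hard-coded at compile time) and keep your code as simple as possible (simple meaning only a few loops, if/else, ...). If the ddl_nl array is really only ever constructed with np.arange, you don't even need this array at all.
Example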
import numpy as np
from numba import jit

@jit(nopython=True)
def foo_nb(df,n_ddl,n_t,ddl_nl):
    out = np.zeros((n_t,n_ddl,n_ddl))
    for i in range(0,n_t):
        for ii in range(ddl_nl.shape[0]):
            ind_1=ddl_nl[ii]
            for jj in range(ddl_nl.shape[0]):
                ind_2=ddl_nl[jj]
                out[i,ind_1,ind_2] = df[i+ii*n_t,i+jj*n_t]
    return out
Timings
#Testing and compilation
A=foo(df,ij)
B=foo_nb(df,n_ddl,n_t,ddl_nl)
print(np.allclose(A,B))
#True
%timeit foo(df,ij)
#16.8 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit foo_nb(df,n_ddl,n_t,ddl_nl)
#674 ns ± 2.56 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
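The remark that the index array is unnecessary when ddl_nl is just np.arange can be sketched as follows. This is an illustration under that assumption only: foo_range and its parameter q (the length of the range) are hypothetical names, not part of the answer above, and the try/except fallback exists only so the snippet runs without numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback (assumption for portability): run as plain Python without numba.
    def njit(func):
        return func


@njit
def foo_range(df, n_ddl, n_t, q):
    # Assumes ddl_nl == np.arange(q), so the loop counters are the target indices.
    out = np.zeros((n_t, n_ddl, n_ddl))
    for i in range(n_t):
        for ii in range(q):
            for jj in range(q):
                # Element (ii, jj) of the strided block df[i::n_t, i::n_t].
                out[i, ii, jj] = df[i + ii * n_t, i + jj * n_t]
    return out


# Usage with the sizes from the question.
n_t, q, n_ddl = 8, 2, 3
np.random.seed(42)
df = np.random.rand(q * n_t, q * n_t)
out_range = foo_range(df, n_ddl, n_t, q)
print(out_range.shape)
```

For a non-contiguous selection such as [0,2], the version that takes ddl_nl as an argument is still needed.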