Error in numba @njit when indexing numpy array
I'm trying to build a function with numba that returns one numpy array indexed by another array.
Here is the simple code without njit:
import numpy as np
import numba as nb

def prueba(arr, eva):
    mask = []
    for i in range(len(arr)):
        mask.append(arr[i])
    return eva[mask]
It works as expected:
>>> prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))
array([6, 7, 8])
However, when I try to compile it with numba in nopython mode (@njit), it raises an error:
@nb.njit
def prueba(arr, eva):
    mask = []
    for i in range(len(arr)):
        mask.append(arr[i])
    return eva[mask]
>>> prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
<ipython-input-9-111474f08921> in <module>
----> 1 prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))
~/.local/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
399 e.patch_message(msg)
400
--> 401 error_rewrite(e, 'typing')
402 except errors.UnsupportedError as e:
403 # Something unsupported is present in the user code, add help info
~/.local/lib/python3.7/site-packages/numba/dispatcher.py in error_rewrite(e, issue_type)
342 raise e
343 else:
--> 344 reraise(type(e), e, None)
345
346 argtypes = []
~/.local/lib/python3.7/site-packages/numba/six.py in reraise(tp, value, tb)
666 value = tp()
667 if value.__traceback__ is not tb:
--> 668 raise value.with_traceback(tb)
669 raise value
670
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(int64, 1d, C), list(int64))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
In definition 6:
All templates rejected with literals.
In definition 7:
All templates rejected without literals.
In definition 8:
All templates rejected with literals.
In definition 9:
All templates rejected without literals.
In definition 10:
All templates rejected with literals.
In definition 11:
All templates rejected without literals.
In definition 12:
TypeError: unsupported array index type list(int64) in [list(int64)]
raised from /home/donielix/.local/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
In definition 13:
TypeError: unsupported array index type list(int64) in [list(int64)]
raised from /home/donielix/.local/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at <ipython-input-8-1b5c9f1a65d5> (6)
[2] During: typing of static-get-item at <ipython-input-8-1b5c9f1a65d5> (6)
File "<ipython-input-8-1b5c9f1a65d5>", line 6:
def prueba(arr, eva):
<source elided>
mask.append(arr[i])
return eva[mask]
^
So my question is: why does this simple code raise an unexpected error, and how should I fix it?
Your indexing, using plain numpy:
In [181]: a, b = np.array([1,2,3]), np.array([5,6,7,8,9,10])
In [182]: b[a]
Out[182]: array([6, 7, 8])
In [183]: def foo(arr, eva):
...:     return eva[arr]
...:
In [184]: foo(a,b)
Out[184]: array([6, 7, 8])
In [186]: timeit foo(a,b)
350 ns ± 9.98 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
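For context on why the non-jitted version in the question works: plain NumPy accepts either a Python list or a 1-D integer array as a fancy (advanced) index and treats them the same way. A quick check, using the same values as above (the variable names are just for illustration):

```python
import numpy as np

eva = np.array([5, 6, 7, 8, 9, 10])
idx_list = [1, 2, 3]            # indices as a Python list
idx_arr = np.array(idx_list)    # the same indices as a 1-D integer array

# Plain NumPy treats both as the same advanced index.
by_list = eva[idx_list]
by_arr = eva[idx_arr]
print(by_list, by_arr)  # [6 7 8] [6 7 8]
```

Numba's nopython mode only supports the array form, which is what the traceback's "unsupported array index type list(int64)" is pointing at.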
Trying to replicate it (and possibly speed it up) with numba:
In [185]: import numba
In [187]: @numba.njit
...: def foo1(arr,eva):
...:     return eva[arr]
...:
In [188]: foo1(a,b)
Out[188]: array([6, 7, 8])
In [189]: timeit foo1(a,b)
968 ns ± 19.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [190]: @numba.njit
...: def foo2(arr,eva):
...:     res = np.empty(len(arr), eva.dtype)
...:     for i in range(len(arr)):
...:         res[i] = eva[arr[i]]
...:     return res
In [191]: foo2(a,b)
Out[191]: array([6, 7, 8])
In [192]: timeit foo2(a,b)
941 ns ± 7.91 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [193]: @numba.njit
...: def foo2(arr,eva):
...:     res = np.empty(len(arr), eva.dtype)
...:     for i,v in enumerate(arr):
...:         res[i] = eva[v]
...:     return res
In [194]: foo2(a,b)
Out[194]: array([6, 7, 8])
In [195]: timeit foo2(a,b)
941 ns ± 17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
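Trying to replace basic numpy functionality with numba doesn't make much sense here: plain eva[arr] takes about 350 ns, while each numba version above takes roughly 940 to 970 ns.
Someone with more numba experience may be able to improve on this.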
Edit
As I noted at first, numba doesn't like indexing with a list. Converting the list to an array works:
In [196]: @numba.njit
...: def prueba(arr, eva):
...:     mask = []
...:     for i in range(len(arr)):
...:         mask.append(arr[i])
...:     mask = np.array(mask)
...:     return eva[mask]
...:
In [197]: prueba(a,b)
Out[197]: array([6, 7, 8])
In [198]: timeit prueba(a,b)
1.5 µs ± 4.79 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
Straight from the docs:
A subset of advanced indexing is also supported: only one advanced
index is allowed, and it has to be a one-dimensional array (it can be
combined with an arbitrary number of basic indices as well).
https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#array-access
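To illustrate the rule quoted above, here is a minimal sketch of indexing patterns that stay within it: exactly one advanced index, which is a 1-D integer array, combined with basic indices (a slice or an integer). The function names pick_columns and pick_rows_col0 are hypothetical, and the try/except fallback to a no-op njit is an assumption added only so the snippet also runs as plain NumPy when numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without numba, fall back to plain Python/NumPy.
    def njit(func):
        return func


@njit
def pick_columns(m, cols):
    # One advanced index (the 1-D array `cols`) combined with a basic slice.
    return m[:, cols]


@njit
def pick_rows_col0(m, rows):
    # One advanced index (the 1-D array `rows`) combined with a basic integer index.
    return m[rows, 0]


m = np.arange(12).reshape(3, 4)
print(pick_columns(m, np.array([0, 2])))
print(pick_rows_col0(m, np.array([2, 0])))
```

Indexing with two advanced indices at once (for example m[rows, cols]) falls outside the quoted subset.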
So, to make your code work, you have to convert mask to a numpy array:
import numpy as np
import numba as nb

@nb.njit
def prueba(arr, eva):
    mask = []
    for i in range(len(arr)):
        mask.append(arr[i])
    mask_as_array = np.array(mask)  # numba only accepts a 1-D array as an advanced index
    return eva[mask_as_array]

prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))
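If the intermediate Python list isn't actually needed, one possible variant is to allocate the index array up front and fill it in the loop, so there is no list-to-array conversion at all. This is a sketch under stated assumptions: prueba_prealloc is a hypothetical name, the int64 index dtype is assumed to match the integer input from the question, and the fallback njit only exists so the snippet runs without numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without numba, fall back to plain Python/NumPy.
    def njit(func):
        return func


@njit
def prueba_prealloc(arr, eva):
    # Build the index directly as a 1-D int64 array instead of a list,
    # so nopython mode sees a supported advanced index.
    mask = np.empty(len(arr), dtype=np.int64)
    for i in range(len(arr)):
        mask[i] = arr[i]
    return eva[mask]


result = prueba_prealloc(np.array([1, 2, 3]), np.array([5, 6, 7, 8, 9, 10]))
print(result)  # [6 7 8]
```

When arr is already an integer array, as in the question, eva[arr] (the foo1 version above) does the same thing without any loop.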