使用 jit nopython 了解 Numba TypingError
Understanding Numba TypingError with jit nopython
我在使用 @jit(nopython=True)
解决(可能是基本的)Numba 错误时遇到问题。它归结为下面的最小示例,它产生一个 TypingError
(下面的完整日志)。如果相关,我正在使用 Python 3.6.10 和 Numba v0.49.0.
错误发生在创建 numpy 数组的 d
行(如果我删除 d
和 return c
,它工作正常)。我该如何解决这个问题?
from numba import jit
import numpy as np
n = 5
foo = np.random.rand(n,n)
@jit(nopython=True)
def bar(x):
a = np.array([0,3,2])
b = np.array([1,2,3])
c = [x[i,j] for i,j in zip(a,b)]
# print(c) # Un-commenting this line solves the issue‽ (per @Ethan's comment)
d = np.array(c)
return d
baz = bar(foo)
完整错误如下:
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
<ipython-input-13-950d2be33d72> in <module>
14 return d
15
---> 16 baz = bar(foo)
17 print(baz)
~/miniconda3/envs/py3k/lib/python3.6/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
399 e.patch_message(msg)
400
--> 401 error_rewrite(e, 'typing')
402 except errors.UnsupportedError as e:
403 # Something unsupported is present in the user code, add help info
~/miniconda3/envs/py3k/lib/python3.6/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
342 raise e
343 else:
--> 344 reraise(type(e), e, None)
345
346 argtypes = []
~/miniconda3/envs/py3k/lib/python3.6/site-packages/numba/core/utils.py in reraise(tp, value, tb)
77 value = tp()
78 if value.__traceback__ is not tb:
---> 79 raise value.with_traceback(tb)
80 raise value
81
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<intrinsic range_iter_len>) with argument(s) of type(s): (zip(iter(array(int64, 1d, C)), iter(array(int64, 1d, C))))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<intrinsic range_iter_len>)
[2] During: typing of call at <ipython-input-13-950d2be33d72> (9)
File "<ipython-input-13-950d2be33d72>", line 9:
def bar(x):
a = np.array([0,3,2])
^
更新: 使用以下函数以类似的方式失败(尽管 print(c)
技巧在这种情况下没有帮助):
@jit(nopython=True)
def bar(x):
a = [0,3,2]
b = [1,2,3]
c = x[a, b]
d = np.array(c)
return d
函数第一个版本的问题,添加 print(c)
解决了这个问题,这对我来说是个谜。 Numba 应该实现 zip
(显然,在这种情况下,当 print(c)
行以某种方式触发时,它可以实现),所以这似乎是一个错误。
函数第二个版本的问题不那么神秘了。根据 current Numba documentation:
Arrays support normal iteration. Full basic indexing and slicing is supported. A subset of advanced indexing is also supported: only one advanced index is allowed, and it has to be a one-dimensional array (it can be combined with an arbitrary number of basic indices as well).
由于您尝试在 c = x[a, b]
行中使用两个高级索引 a
和 b
,Numba 不支持该代码。事实上,这就是冗长的错误消息 Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, C), tuple(array(int64, 1d, C) x 2))
所说的。
如果我们改写 c=x[a,2]
,那么代码会工作,这与 Numba 允许一个高级索引的承诺一致。
总的来说,我发现使用 Numba 最安全的方法是在没有 NumPy 更高级功能的情况下以循环方式编写。这有点不幸——因为这几乎就好像我们需要用 C 的方言而不是 Python 来编写——但从好的方面来说,它仍然比实际编写 C 方便得多。
在这种情况下,以下代码运行良好:
@jit(nopython=True)
def bar(x):
a = np.array([0,3,2])
b = np.array([1,2,3])
c = np.empty(len(a))
for i in range(len(a)):
c[i] = x[a[i], b[i]]
return c
我遇到了类似的问题,只是通过更新 numba 解决了这个问题:
pip install --upgrade numba
我在使用 @jit(nopython=True)
解决(可能是基本的)Numba 错误时遇到问题。它归结为下面的最小示例,它产生一个 TypingError
(下面的完整日志)。如果相关,我正在使用 Python 3.6.10 和 Numba v0.49.0.
错误发生在创建 numpy 数组的 d
行(如果我删除 d
和 return c
,它工作正常)。我该如何解决这个问题?
from numba import jit
import numpy as np
n = 5
foo = np.random.rand(n,n)
@jit(nopython=True)
def bar(x):
a = np.array([0,3,2])
b = np.array([1,2,3])
c = [x[i,j] for i,j in zip(a,b)]
# print(c) # Un-commenting this line solves the issue‽ (per @Ethan's comment)
d = np.array(c)
return d
baz = bar(foo)
完整错误如下:
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
<ipython-input-13-950d2be33d72> in <module>
14 return d
15
---> 16 baz = bar(foo)
17 print(baz)
~/miniconda3/envs/py3k/lib/python3.6/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
399 e.patch_message(msg)
400
--> 401 error_rewrite(e, 'typing')
402 except errors.UnsupportedError as e:
403 # Something unsupported is present in the user code, add help info
~/miniconda3/envs/py3k/lib/python3.6/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
342 raise e
343 else:
--> 344 reraise(type(e), e, None)
345
346 argtypes = []
~/miniconda3/envs/py3k/lib/python3.6/site-packages/numba/core/utils.py in reraise(tp, value, tb)
77 value = tp()
78 if value.__traceback__ is not tb:
---> 79 raise value.with_traceback(tb)
80 raise value
81
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<intrinsic range_iter_len>) with argument(s) of type(s): (zip(iter(array(int64, 1d, C)), iter(array(int64, 1d, C))))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<intrinsic range_iter_len>)
[2] During: typing of call at <ipython-input-13-950d2be33d72> (9)
File "<ipython-input-13-950d2be33d72>", line 9:
def bar(x):
a = np.array([0,3,2])
^
更新: 使用以下函数以类似的方式失败(尽管 print(c)
技巧在这种情况下没有帮助):
@jit(nopython=True)
def bar(x):
a = [0,3,2]
b = [1,2,3]
c = x[a, b]
d = np.array(c)
return d
函数第一个版本的问题,添加 print(c)
解决了这个问题,这对我来说是个谜。 Numba 应该实现 zip
(显然,在这种情况下,当 print(c)
行以某种方式触发时,它可以实现),所以这似乎是一个错误。
函数第二个版本的问题不那么神秘了。根据 current Numba documentation:
Arrays support normal iteration. Full basic indexing and slicing is supported. A subset of advanced indexing is also supported: only one advanced index is allowed, and it has to be a one-dimensional array (it can be combined with an arbitrary number of basic indices as well).
由于您尝试在 c = x[a, b]
行中使用两个高级索引 a
和 b
,Numba 不支持该代码。事实上,这就是冗长的错误消息 Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, C), tuple(array(int64, 1d, C) x 2))
所说的。
如果我们改写 c=x[a,2]
,那么代码会工作,这与 Numba 允许一个高级索引的承诺一致。
总的来说,我发现使用 Numba 最安全的方法是在没有 NumPy 更高级功能的情况下以循环方式编写。这有点不幸——因为这几乎就好像我们需要用 C 的方言而不是 Python 来编写——但从好的方面来说,它仍然比实际编写 C 方便得多。
在这种情况下,以下代码运行良好:
@jit(nopython=True)
def bar(x):
a = np.array([0,3,2])
b = np.array([1,2,3])
c = np.empty(len(a))
for i in range(len(a)):
c[i] = x[a[i], b[i]]
return c
我遇到了类似的问题,只是通过更新 numba 解决了这个问题:
pip install --upgrade numba