Problem with Numba jit: "Typing error" and "All templates rejected with/without literals"
I'm implementing a program to solve a differential equation in Python 3.7.3, and there is one function I can't get Numba to compile. Its latest version is:
import numpy as np
from numba import jit, uint16, complex128, prange
# Here is the setup of the program, as well as variable initialization
@jit((complex128[:, :, :], uint16, complex128[:, :], complex128[:, :], complex128[:, :]), nopython=True)
def upd_x(rhs: np.ndarray, m: int, s: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    x = np.zeros((3, m, m//2+1))
    x[2] = s*(1-a*(rhs[0]+rhs[1]))
    for i in range(2):
        x[i] = a*(rhs[i]+b*x[2])
    return x
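What it should do is take the "right hand side" (rhs) of the equation and update x with the Schur complement method. (x has 3 components, which are real fields, and the code does the "updating" in Fourier space, which is why the last axis has length m//2+1 instead of m.) When I run the code, I get the following message: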
Traceback (most recent call last):
  File "C:/Users/Username/Desktop/Program/Program.py", line 95, in <module>
    @jit((complex128[:, :, :], uint16, complex128[:, :], complex128[:, :], complex128[:, :]), nopython=True)
  File "C:\Program Files\Python37\lib\site-packages\numba\decorators.py", line 186, in wrapper
    disp.compile(sig)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "C:\Program Files\Python37\lib\site-packages\numba\dispatcher.py", line 659, in compile
    cres = self._compiler.compile(args, return_type)
  File "C:\Program Files\Python37\lib\site-packages\numba\dispatcher.py", line 83, in compile
    pipeline_class=self.pipeline_class)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 955, in compile_extra
    return pipeline.compile_extra(func)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 377, in compile_extra
    return self._compile_bytecode()
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 886, in _compile_bytecode
    return self._compile_core()
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 873, in _compile_core
    res = pm.run(self.status)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 254, in run
    raise patched_exception
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 245, in run
    stage()
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 501, in stage_nopython_frontend
    self.locals)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 1105, in type_inference_stage
    infer.propagate()
  File "C:\Program Files\Python37\lib\site-packages\numba\typeinfer.py", line 915, in propagate
    raise errors[0]
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 3d, C), Literal[int](2), array(complex128, 2d, C))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of staticsetitem at C:/Users/User/Desktop/Program/Program.py (98)

File "Programa.py", line 98:
def upd_x(rhs, m, s, a, b):
    <source elided>
    x = np.zeros((3, m, m//2+1))
    x[2] = s*(1-a*(rhs[0]+rhs[1]))
    ^
I don't understand why the error says the argument types are unsupported, and I don't know what I need to correct. I'm using numba==0.44.1 and numpy==1.16.1.
Thanks a lot.
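The m//2+1 length of the last axis is the shape NumPy's real-input FFT produces: for a real m×m field, `np.fft.rfft2` keeps only the non-negative frequencies along the last axis, and the result is complex. The question doesn't say which FFT routine is used, so this is a NumPy-only sketch assuming `np.fft.rfft2`/`np.fft.irfft2`; the field values are arbitrary and chosen only for illustration.

```python
import numpy as np

m = 4
# An arbitrary real m x m field (illustrative values only).
field = np.arange(m * m, dtype=np.float64).reshape(m, m)

# Real-to-complex 2D FFT: the last axis shrinks to m//2 + 1.
spectrum = np.fft.rfft2(field)
print(spectrum.shape)  # (4, 3)
print(spectrum.dtype)  # complex128

# The inverse transform needs the original shape to recover the real field.
back = np.fft.irfft2(spectrum, s=(m, m))
print(np.allclose(back, field))  # True
```

This is also why every array handed to `upd_x` is complex128 even though the underlying fields are real.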
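It looks like Numba can't reconcile the type of the output x: `np.zeros` without a `dtype` creates a float64 array, so storing the complex128 result in x[2] fails. That is exactly the setitem signature in the error message: an array(float64, 3d, C) receiving an array(complex128, 2d, C). So I added `dtype=np.complex128` to x. Then you run into mixing np.int64 and uint16 in the shape argument of `np.zeros`, because the 3 is interpreted as an i64; casting m to np.int64 first avoids that. So the following compiles: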
import numpy as np
from numba import jit, uint16, complex128, prange
# Here is the setup of the program, as well as variable initialization
@jit(complex128[:,:,:](complex128[:, :, :], uint16, complex128[:, :], complex128[:, :], complex128[:, :]), nopython=True)
def upd_x(rhs: np.ndarray, m: int, s: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
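    # Cast to int64 so every entry of the shape tuple has the same integer type.
    mx = np.int64(m)
    # complex128 so the complex results below can be stored in x.
    x = np.zeros((3, mx, mx//2+1), dtype=np.complex128)
    x[2] = s*(1-a*(rhs[0]+rhs[1]))
    for i in range(2):
        x[i] = a*(rhs[i]+b*x[2])
    return x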
Also note that I added the return type to the signature passed to `@jit`, but I don't think that is necessary.
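The dtype part of the fix can be seen without Numba at all. In this minimal NumPy sketch, `np.zeros` without `dtype` gives float64, and plain NumPy only emits a ComplexWarning and silently keeps the real part when a complex array is stored into it. Numba's nopython mode refuses to type that lossy conversion, which is the rejected setitem in the traceback. With `dtype=np.complex128` the values are stored intact. The array values are illustrative only.

```python
import warnings
import numpy as np

m = 4
value = (1 + 1j) * np.ones((m, m // 2 + 1))  # a complex 2D slab, like s*(...) in upd_x

# Default dtype is float64: NumPy warns and drops the imaginary part.
x_default = np.zeros((3, m, m // 2 + 1))
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the ComplexWarning for this demo
    x_default[2] = value
print(x_default.dtype, x_default[2, 0, 0])  # float64 1.0

# With an explicit complex dtype the assignment keeps the full value.
x_complex = np.zeros((3, m, m // 2 + 1), dtype=np.complex128)
x_complex[2] = value
print(x_complex.dtype, x_complex[2, 0, 0])  # complex128 (1+1j)
```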
Then I used these inputs:
m = 4
x = np.zeros((3, m, m//2+1), dtype=np.complex128) + 2 + 2j
y = np.zeros((m, m//2 + 1 ), dtype=np.complex128) + 1 + 1j
upd_x(x, np.uint16(m), y, y, y)
and I think this returns something sensible.
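One way to confirm that the compiled result is sensible is to compare it with the same update written in plain NumPy. The sketch below uses a hypothetical helper `upd_x_reference` (not part of the original code). It takes the output shape from `s` instead of `m`, and it uses the same constant inputs as above. With these inputs every entry can be checked by hand: rhs[0]+rhs[1] = 4+4j, so x[2] = (1+1j)(1-(1+1j)(4+4j)) = 9-7j, and x[0] = x[1] = (1+1j)((2+2j)+(1+1j)(9-7j)) = 14+22j.

```python
import numpy as np

def upd_x_reference(rhs, s, a, b):
    """Hypothetical plain-NumPy version of upd_x (no Numba), for cross-checking."""
    x = np.zeros((3,) + s.shape, dtype=np.complex128)
    x[2] = s * (1 - a * (rhs[0] + rhs[1]))
    for i in range(2):
        x[i] = a * (rhs[i] + b * x[2])
    return x

m = 4
rhs = np.zeros((3, m, m // 2 + 1), dtype=np.complex128) + 2 + 2j
y = np.zeros((m, m // 2 + 1), dtype=np.complex128) + 1 + 1j
x = upd_x_reference(rhs, y, y, y)
print(x[0, 0, 0], x[1, 0, 0], x[2, 0, 0])  # (14+22j) (14+22j) (9-7j)
```

With Numba installed, `np.allclose(upd_x(rhs, np.uint16(m), y, y, y), upd_x_reference(rhs, y, y, y))` would be expected to hold for the compiled version from the answer.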