创建一个 Numba 类型的字典,其中整数作为键,float64 数组作为值
Create a Numba typed dictionary with integers as keys and arrays of float64 as values
我需要定义一个字典,其中整数作为键,float64 数组作为值。在 Python 我可以定义它:
import numpy as np
d = {3: np.array([0, 1, 2, 3, 4])}
为了在 Numba 编译的函数中创建相同类型的字典,我这样做了
import numba
@numba.njit()
def generate_d():
d = Dict.empty(types.int64, types.float64[:])
return d
但我在编译时遇到错误。
鉴于非常简单的说明,我不明白为什么会出错。
这是我运行generate_d()
:
时的错误
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
/tmp/ipykernel_536115/3907784652.py in <module>
----> 1 generate_d()
~/envs/oasis/lib/python3.8/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
466 e.patch_message(msg)
467
--> 468 error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
~/envs/oasis/lib/python3.8/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
407 raise e
408 else:
--> 409 raise e.with_traceback(None)
410
411 argtypes = []
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
>>> getitem(class(float64), slice<a:b>)
There are 22 candidate implementations:
- Of which 22 did not match due to:
Overload of function 'getitem': File: <numerous>: Line N/A.
With argument(s): '(class(float64), slice<a:b>)':
No match.
During: typing of intrinsic-call at /tmp/ipykernel_536115/3046996983.py (4)
During: typing of static-get-item at /tmp/ipykernel_536115/3046996983.py (4)
File "../../../../tmp/ipykernel_536115/3046996983.py", line 4:
<source missing, REPL/exec in use?>
即使我显式签名,我也会得到同样的错误
@numba.njit("float64[:]()")
def generate_d():
d = Dict.empty(types.int64, types.float64[:])
return d
我使用的是 numba v 0.55.1、numpy 1.20.3
我怎样才能让它工作?
据我所知,JIT 函数尚不支持类型表达式(Numba 版本 0.54.1)。您需要在函数外部创建类型。这是一个例子:
import numba
from numba.typed import Dict
# Type defined outside the JIT function
FloatArrayType = numba.types.float64[:]
@numba.njit
def generate_d():
d = Dict.empty(numba.types.int64, FloatArrayType) # <-- and used here
return d
我需要定义一个字典,其中整数作为键,float64 数组作为值。在 Python 我可以定义它:
import numpy as np
d = {3: np.array([0, 1, 2, 3, 4])}
为了在 Numba 编译的函数中创建相同类型的字典,我这样做了
import numba
@numba.njit()
def generate_d():
d = Dict.empty(types.int64, types.float64[:])
return d
但我在编译时遇到错误。 鉴于非常简单的说明,我不明白为什么会出错。
这是我运行generate_d()
:
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
/tmp/ipykernel_536115/3907784652.py in <module>
----> 1 generate_d()
~/envs/oasis/lib/python3.8/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
466 e.patch_message(msg)
467
--> 468 error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
~/envs/oasis/lib/python3.8/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
407 raise e
408 else:
--> 409 raise e.with_traceback(None)
410
411 argtypes = []
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
>>> getitem(class(float64), slice<a:b>)
There are 22 candidate implementations:
- Of which 22 did not match due to:
Overload of function 'getitem': File: <numerous>: Line N/A.
With argument(s): '(class(float64), slice<a:b>)':
No match.
During: typing of intrinsic-call at /tmp/ipykernel_536115/3046996983.py (4)
During: typing of static-get-item at /tmp/ipykernel_536115/3046996983.py (4)
File "../../../../tmp/ipykernel_536115/3046996983.py", line 4:
<source missing, REPL/exec in use?>
即使我显式签名,我也会得到同样的错误
@numba.njit("float64[:]()")
def generate_d():
d = Dict.empty(types.int64, types.float64[:])
return d
我使用的是 numba v 0.55.1、numpy 1.20.3
我怎样才能让它工作?
据我所知,JIT 函数尚不支持类型表达式(Numba 版本 0.54.1)。您需要在函数外部创建类型。这是一个例子:
import numba
from numba.typed import Dict
# Type defined outside the JIT function
FloatArrayType = numba.types.float64[:]
@numba.njit
def generate_d():
d = Dict.empty(numba.types.int64, FloatArrayType) # <-- and used here
return d