创建一个 Numba 类型的字典,其中整数作为键,float64 数组作为值

Create a Numba typed dictionary with integers as keys and arrays of float64 as values

我需要定义一个字典,其中整数作为键,float64 数组作为值。在 Python 我可以定义它:

import numpy as np

d = {3: np.array([0, 1, 2, 3, 4])}

为了在 Numba 编译的函数中创建相同类型的字典,我这样做了

import numba

@numba.njit()
def generate_d():

    d = Dict.empty(types.int64, types.float64[:])

    return d

但我在编译时遇到错误。 鉴于非常简单的说明,我不明白为什么会出错。

这是我运行generate_d():

时的错误
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
/tmp/ipykernel_536115/3907784652.py in <module>
----> 1 generate_d()

~/envs/oasis/lib/python3.8/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
    466                 e.patch_message(msg)
    467 
--> 468             error_rewrite(e, 'typing')
    469         except errors.UnsupportedError as e:
    470             # Something unsupported is present in the user code, add help info

~/envs/oasis/lib/python3.8/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
    407                 raise e
    408             else:
--> 409                 raise e.with_traceback(None)
    410 
    411         argtypes = []

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(class(float64), slice<a:b>)
 
There are 22 candidate implementations:
      - Of which 22 did not match due to:
      Overload of function 'getitem': File: <numerous>: Line N/A.
        With argument(s): '(class(float64), slice<a:b>)':
       No match.

During: typing of intrinsic-call at /tmp/ipykernel_536115/3046996983.py (4)
During: typing of static-get-item at /tmp/ipykernel_536115/3046996983.py (4)

File "../../../../tmp/ipykernel_536115/3046996983.py", line 4:
<source missing, REPL/exec in use?>

即使我显式签名,我也会得到同样的错误

@numba.njit("float64[:]()")
def generate_d():

    d = Dict.empty(types.int64, types.float64[:])

    return d

我使用的是 numba v 0.55.1、numpy 1.20.3

我怎样才能让它工作?

据我所知,JIT 函数尚不支持类型表达式(Numba 版本 0.54.1)。您需要在函数外部创建类型。这是一个例子:

import numba
from numba.typed import Dict

# Type defined outside the JIT function
FloatArrayType = numba.types.float64[:]

@numba.njit
def generate_d():
    d = Dict.empty(numba.types.int64, FloatArrayType)  # <-- and used here
    return d