如何在 numba njit 函数中使用元组键创建字典

How to create a dictionary with tuple keys in a numba njit fuction

我对 numba 非常缺乏经验(和发布问题)所以希望这不是一个未指定的问题。

我正在尝试创建一个涉及字典的 jitted 函数。我希望字典将元组作为键,将浮点数作为值。下面是从 numba 帮助中找到的一些代码 on the numba docs,我用它来帮助证明我的问题。

我了解 numba 希望指定变量类型。我认为的问题是我没有将正确的 numba 类型指定为函数内的字典键。我查看了 this question,但仍然不知道该怎么做。

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

# Make array type.  Type-expression is not supported in jit functions.
float_array = types.float64[:]

@njit
def foo():
    list_out=[]
    # Make dictionary
    d = Dict.empty(
        key_type=types.Tuple, #<= I suppose im not putting the right 'type' here
        value_type=float_array,
    )
    # an example of how I would like to fill the dictionary
    d[(1,1)] = np.arange(3).astype(np.float64)
    d[(2,2)] = np.arange(3, 6).astype(np.float64)
    list_out.append(d[(2,2)])
    return list_out

list_out = foo()

感谢任何帮助或指导。感谢您的宝贵时间!

types.Tuple 不完整类型 ,因此不是有效类型。您需要指定元组中项目的类型。在这种情况下,您可以使用 types.UniTuple(types.int32, 2) 作为完整的键类型(包含两个 32 位整数的元组)。这是结果代码:

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

# Make key type with two 32-bit integer items.
key_type = types.UniTuple(types.int32, 2)

# Make array type.  Type-expression is not supported in jit functions.
float_array = types.float64[:]

@njit
def foo():
    list_out=[]
    # Make dictionary
    d = Dict.empty(
        key_type=key_type, 
        value_type=float_array,
    )
    # an example of how I would like to fill the dictionary
    d[(1,1)] = np.arange(3).astype(np.float64)
    d[(2,2)] = np.arange(3, 6).astype(np.float64)
    list_out.append(d[(2,2)])
    return list_out

list_out = foo()

顺便说一句,请注意 arange 在参数中接受一个 dtype,因此您可以直接使用 np.arange(3, dtype=np.float64),这在使用 astype 时效率更高。