How to specify a dict type in the numba njit decorator, and why timeit does not seem to reuse the compiled version

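I wrote a program that groups the entries of a target numpy array using the indices given in an index numpy array. It returns a dictionary whose keys are the indices and whose values are all the entries that share that index. For performance reasons I use numba, which has a clear advantage over the other Python approaches I tried.

I want to specify the argument types in the njit decorator so that the function is compiled up front. I know numba can also do type inference.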

import numpy as np
import numba as nb
import random
import timeit


# uncomment njit argument to see error
@nb.njit#((nb.int64[:], nb.int64[:], nb.types.DictType[nb.int64, nb.int64[:]]))
def sort5(des, indices, d):
    present_indices = np.unique(indices)
    for i in nb.prange(present_indices.shape[0]):
        d[i] = des[indices==present_indices[i]]
    return d

indices = np.array([random.randint(0,9) for i in range(3500)])
des = np.array([i for i in range(len(indices))])
d = nb.typed.Dict.empty(
    key_type=nb.types.int64,
    value_type=nb.types.int64[:],
)

# Compiling but does not reuse for 9999 runs I think
print("sort5", timeit.timeit(lambda: sort5(des, indices, d), number=10000))
indices = np.array([random.randint(0,9) for i in range(3500)])
des = np.array([i for i in range(len(indices))])

# But only this timeit seems to use compiled version from previous timeit for all 10000 runs
print("sort5", timeit.timeit(lambda: sort5(des, indices, d), number=10000))

1: How should the dict type be written in the njit decorator?

What I tried:

# Trying to get type instance
nb.typeof(d)
# DictType[int64,array(int64, 1d, A)]<iv=None>

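Based on the output above, I tried replacing the decorator with @nb.njit((nb.int64[:], nb.int64[:], nb.types.DictType[nb.int64, nb.int64[:]])). I also tried nb.typed.Dict[nb.int64, nb.int64[:]], but that gives a similar error, this time mentioning the ABCMeta class.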

TypeError: '_TypeMetaclass' object is not subscriptable
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-5-98e6f766338f> in <module>
      3 import random
      4 import timeit
----> 5 @nb.njit((nb.int64[:], nb.int64[:], nb.types.DictType[nb.int64, nb.int64[:]]))
      6 def sort5(des, indices, d):
      7     present_indices = np.unique(indices)

TypeError: '_TypeMetaclass' object is not subscriptable

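I think I am missing something trivial here, but I am not sure what.

2: Why is the second run faster?

Also, looking at the program's output, the second timeit seems to use the compiled version while the first does not: the second finishes in 1.17 seconds, whereas the first takes 2.66.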

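You just have the wrong brackets: nb.types.DictType is called with parentheses, DictType(key_type, value_type), not subscripted with square brackets.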

import numpy as np
import numba as nb
import random
import timeit


@nb.njit([(nb.int64[:], nb.int64[:], nb.types.DictType(nb.int64, nb.int64[:]))])
def sort5(des, indices, d):
    present_indices = np.unique(indices)
    for i in nb.prange(present_indices.shape[0]):
        d[i] = des[indices==present_indices[i]]
    return d

indices = np.array([random.randint(0,9) for i in range(3500)])
des = np.array([i for i in range(len(indices))])
d = nb.typed.Dict.empty(
    key_type=nb.types.int64,
    value_type=nb.types.int64[:],
)

# Compiling but does not reuse for 9999 runs I think
print("sort5", timeit.timeit(lambda: sort5(des, indices, d), number=10000))
indices = np.array([random.randint(0,9) for i in range(3500)])
des = np.array([i for i in range(len(indices))])

# But only this timeit seems to use compiled version from previous timeit for all 10000 runs
print("sort5", timeit.timeit(lambda: sort5(des, indices, d), number=10000))

This gives me:

sort5 1.8259385739999914
sort5 1.9655513189999994

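So I guess it also resolves the timing difference from your second question. With an explicit signature, numba compiles the function when the decorator runs, so the first timeit no longer includes the compilation.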

Since this helped me solve the same problem, I think it is worth adding that I used the following to check which types numba inferred:

sort5.overloads.keys()
odict_keys([(array(int64, 1d, A), array(int64, 1d, A), DictType[int64,array(int64, 1d, A)]<iv=None>)])