使用 numba 索引 numpy 数组时出现 TypeError

TypeError when indexing numpy array using numba

我需要根据另一个包含 class 成员资格 (labels) 信息的数组对一维 numpy 数组(下图:data)中的元素求和。我在下面的代码中使用 numba 来加快速度。但是,如果我没有在 ret[int(find(labels, g))] += y 行中使用 int() 显式转换,我会收到一条错误消息:

TypeError: unsupported array index type ?int64

是否有比显式转换更好的解决方法?

import numpy as np
from numba import jit

labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)

@jit(nopython=True)
def find(seq, value):
    for ct, x in enumerate(seq):
        if x == value:
            return ct

@jit(nopython=True)
def subsumNumba(data, groups, labels):
    ret = np.zeros(len(labels))
    for y, g in zip(data, groups):
        # not working without casting with int()
        ret[int(find(labels, g))] += y
    return ret

问题是 find 可以 return 一个 intNone 如果它没有找到任何东西,因此我认为 ?int64错误。为避免强制转换,您需要在 find 未找到所需值而退出时提供一个 int return 值,然后在调用者中处理它。