使用 numba 索引 numpy 数组时出现 TypeError
TypeError when indexing numpy array using numba
我需要根据另一个包含 class 成员资格 (labels
) 信息的数组对一维 numpy
数组(下图:data
)中的元素求和。我在下面的代码中使用 numba
来加快速度。但是,如果我没有在 ret[int(find(labels, g))] += y
行中使用 int()
显式转换,我会收到一条错误消息:
TypeError: unsupported array index type ?int64
是否有比显式转换更好的解决方法?
import numpy as np
from numba import jit
labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)
@jit(nopython=True)
def find(seq, value):
for ct, x in enumerate(seq):
if x == value:
return ct
@jit(nopython=True)
def subsumNumba(data, groups, labels):
ret = np.zeros(len(labels))
for y, g in zip(data, groups):
# not working without casting with int()
ret[int(find(labels, g))] += y
return ret
问题是 find
可以 return 一个 int
或 None
如果它没有找到任何东西,因此我认为 ?int64
错误。为避免强制转换,您需要在 find
未找到所需值而退出时提供一个 int
return 值,然后在调用者中处理它。
我需要根据另一个包含 class 成员资格 (labels
) 信息的数组对一维 numpy
数组(下图:data
)中的元素求和。我在下面的代码中使用 numba
来加快速度。但是,如果我没有在 ret[int(find(labels, g))] += y
行中使用 int()
显式转换,我会收到一条错误消息:
TypeError: unsupported array index type ?int64
是否有比显式转换更好的解决方法?
import numpy as np
from numba import jit
labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)
@jit(nopython=True)
def find(seq, value):
for ct, x in enumerate(seq):
if x == value:
return ct
@jit(nopython=True)
def subsumNumba(data, groups, labels):
ret = np.zeros(len(labels))
for y, g in zip(data, groups):
# not working without casting with int()
ret[int(find(labels, g))] += y
return ret
问题是 find
可以 return 一个 int
或 None
如果它没有找到任何东西,因此我认为 ?int64
错误。为避免强制转换,您需要在 find
未找到所需值而退出时提供一个 int
return 值,然后在调用者中处理它。