在 numba 的 jitclass 中索引多维 numpy 数组
Indexing multidimensional numpy array inside numba's jitclass
我正在尝试将一个小的多维数组插入到 numba jitclass 中的一个更大的数组中。小数组设置了索引列表定义的大数组的特定位置。
以下 MWE 显示了没有 numba 的问题 - 一切都按预期工作
import numpy as np
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
# solution 1 using pure python
def nonNumbaFunction1(self, idx, values):
self.A[idx[:, None], idx] = values
# solution 2 using pure python
def nonNumbaFunction2(self, idx, values):
self.A[np.ix_(idx, idx)] = values
if __name__ == "__main__":
n = 6
m = 8
obj = NumbaClass(n, m)
print(f'A =\n{obj.A}')
idx = np.array([0, 2, 5])
values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
print(f'values =\n{values}')
obj.nonNumbaFunction1(idx, values)
print(f'A =\n{obj.A}')
obj.nonNumbaFunction2(idx, values)
print(f'A =\n{obj.A}')
nonNumbaFunction1
和 nonNumbaFunction2
这两个函数在 numba class 中都不起作用。所以我目前的解决方案看起来像这样,在我看来这并不是很好
import numpy as np
from numba import jitclass
from numba import int64, float64
from collections import OrderedDict
specs = OrderedDict()
specs['A'] = float64[:, :]
@jitclass(specs)
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
# solution for numba jitclass
def numbaFunction(self, idx, values):
for i in range(len(values)):
idxi = idx[i]
for j in range(len(values)):
idxj = idx[j]
self.A[idxi, idxj] = values[i, j]
if __name__ == "__main__":
n = 6
m = 8
obj = NumbaClass(n, m)
print(f'A =\n{obj.A}')
idx = np.array([0, 2, 5])
values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
print(f'values =\n{values}')
obj.numbaFunction(idx, values)
print(f'A =\n{obj.A}')
所以我的问题是:
- 有谁知道 numba 中这个索引的解决方案或者是否有其他矢量化解决方案?
nonNumbaFunction1
有更快的解决方案吗?
了解插入的数组很小(4x4 到 10x10)可能会有用,但此索引出现在嵌套循环中,因此它也必须快速安静!稍后我也需要对三维对象进行类似的索引。
由于 numba 索引支持的限制,我认为您没有比自己编写 for 循环更好的方法了。要使其跨维度通用,您可以使用 generated_jit
装饰器进行专门化。像这样:
def set_2d(target, values, idx):
for i in range(values.shape[0]):
for j in range(values.shape[1]):
target[idx[i], idx[j]] = values[i, j]
def set_3d(target, values, idx):
for i in range(values.shape[0]):
for j in range(values.shape[1]):
for k in range(values.shape[2]):
target[idx[i], idx[j], idx[k]] = values[i, j, l]
@numba.generated_jit
def set_nd(target, values, idx):
if target.ndim == 2:
return set_2d
elif target.ndim == 3:
return set_3d
然后,这可以用在你的 jitclass 中
specs = OrderedDict()
specs['A'] = float64[:, :]
@jitclass(specs)
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
def numbaFunction(self, idx, values):
set_nd(self.A, values, idx)
我正在尝试将一个小的多维数组插入到 numba jitclass 中的一个更大的数组中。小数组设置了索引列表定义的大数组的特定位置。
以下 MWE 显示了没有 numba 的问题 - 一切都按预期工作
import numpy as np
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
# solution 1 using pure python
def nonNumbaFunction1(self, idx, values):
self.A[idx[:, None], idx] = values
# solution 2 using pure python
def nonNumbaFunction2(self, idx, values):
self.A[np.ix_(idx, idx)] = values
if __name__ == "__main__":
n = 6
m = 8
obj = NumbaClass(n, m)
print(f'A =\n{obj.A}')
idx = np.array([0, 2, 5])
values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
print(f'values =\n{values}')
obj.nonNumbaFunction1(idx, values)
print(f'A =\n{obj.A}')
obj.nonNumbaFunction2(idx, values)
print(f'A =\n{obj.A}')
nonNumbaFunction1
和 nonNumbaFunction2
这两个函数在 numba class 中都不起作用。所以我目前的解决方案看起来像这样,在我看来这并不是很好
import numpy as np
from numba import jitclass
from numba import int64, float64
from collections import OrderedDict
specs = OrderedDict()
specs['A'] = float64[:, :]
@jitclass(specs)
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
# solution for numba jitclass
def numbaFunction(self, idx, values):
for i in range(len(values)):
idxi = idx[i]
for j in range(len(values)):
idxj = idx[j]
self.A[idxi, idxj] = values[i, j]
if __name__ == "__main__":
n = 6
m = 8
obj = NumbaClass(n, m)
print(f'A =\n{obj.A}')
idx = np.array([0, 2, 5])
values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
print(f'values =\n{values}')
obj.numbaFunction(idx, values)
print(f'A =\n{obj.A}')
所以我的问题是:
- 有谁知道 numba 中这个索引的解决方案或者是否有其他矢量化解决方案?
nonNumbaFunction1
有更快的解决方案吗?
了解插入的数组很小(4x4 到 10x10)可能会有用,但此索引出现在嵌套循环中,因此它也必须快速安静!稍后我也需要对三维对象进行类似的索引。
由于 numba 索引支持的限制,我认为您没有比自己编写 for 循环更好的方法了。要使其跨维度通用,您可以使用 generated_jit
装饰器进行专门化。像这样:
def set_2d(target, values, idx):
for i in range(values.shape[0]):
for j in range(values.shape[1]):
target[idx[i], idx[j]] = values[i, j]
def set_3d(target, values, idx):
for i in range(values.shape[0]):
for j in range(values.shape[1]):
for k in range(values.shape[2]):
target[idx[i], idx[j], idx[k]] = values[i, j, l]
@numba.generated_jit
def set_nd(target, values, idx):
if target.ndim == 2:
return set_2d
elif target.ndim == 3:
return set_3d
然后,这可以用在你的 jitclass 中
specs = OrderedDict()
specs['A'] = float64[:, :]
@jitclass(specs)
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
def numbaFunction(self, idx, values):
set_nd(self.A, values, idx)