在 numba 的 jitclass 中索引多维 numpy 数组

Indexing multidimensional numpy array inside numba's jitclass

我正在尝试将一个小的多维数组插入到 numba jitclass 中的一个更大的数组中。小数组设置了索引列表定义的大数组的特定位置。

以下 MWE 显示了没有 numba 的问题 - 一切都按预期工作

import numpy as np

class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution 1 using pure python
    def nonNumbaFunction1(self, idx, values):
        self.A[idx[:, None], idx] = values

    # solution 2 using pure python
    def nonNumbaFunction2(self, idx, values):
        self.A[np.ix_(idx, idx)] = values

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.nonNumbaFunction1(idx, values)
    print(f'A =\n{obj.A}')

    obj.nonNumbaFunction2(idx, values)
    print(f'A =\n{obj.A}')

nonNumbaFunction1nonNumbaFunction2 这两个函数在 numba class 中都不起作用。所以我目前的解决方案看起来像这样,在我看来这并不是很好

import numpy as np

from numba import jitclass      
from numba import int64, float64
from collections import OrderedDict

specs = OrderedDict()
specs['A'] = float64[:, :]

@jitclass(specs)
class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution for numba jitclass
    def numbaFunction(self, idx, values):
        for i in range(len(values)):
            idxi = idx[i]
            for j in range(len(values)):
                idxj = idx[j]
                self.A[idxi, idxj] = values[i, j]

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.numbaFunction(idx, values)
    print(f'A =\n{obj.A}')

所以我的问题是:

了解插入的数组很小(4x4 到 10x10)可能会有用,但此索引出现在嵌套循环中,因此它也必须快速安静!稍后我也需要对三维对象进行类似的索引。

由于 numba 索引支持的限制,我认为您没有比自己编写 for 循环更好的方法了。要使其跨维度通用,您可以使用 generated_jit 装饰器进行专门化。像这样:

def set_2d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            target[idx[i], idx[j]] = values[i, j]

def set_3d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            for k in range(values.shape[2]):
                target[idx[i], idx[j], idx[k]] = values[i, j, l]

@numba.generated_jit
def set_nd(target, values, idx):
    if target.ndim == 2:
        return set_2d
    elif target.ndim == 3:
        return set_3d

然后,这可以用在你的 jitclass 中

specs = OrderedDict()
specs['A'] = float64[:, :]

@jitclass(specs)
class NumbaClass(object):
    def __init__(self, n, m):
        self.A = np.zeros((n, m))
    def numbaFunction(self, idx, values):
        set_nd(self.A, values, idx)