numba njit 在 2D np.array 索引上给出 my 和 error

numba njit give my and error on a 2D np.array indexation

我正在尝试在 njit 函数中索引一个二维矩阵 B,其中包含我想要的索引的向量 a,矩阵的一部分 D 这是一个最小的例子:

import numba as nb
import numpy as np

@nb.njit()
def test(N,P,B,D):
    for i in range(N):
        a = D[i,:]
        b =  B[i,a]
        P[:,i] =b

P = np.zeros((5,5))
B = np.random.random((5,5))*100
D = (np.random.random((5,5))*5).astype(np.int32)
print(D)
N = 5
print(P)
test(N,P,B,D)
print(P)

我在 b = B[i,a]

行收到 numba 错误
File "dj.py", line 10:
def test(N,P,B,D):
    <source elided>
        a = D[i,:]
        b =  B[i,a]
        ^

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

我不明白我在这里做错了什么。 该代码在没有 @nb.njit() 装饰器

的情况下工作

numba 不支持 numpy 支持的所有 "fancy-indexing" - 在这种情况下,问题是使用 a 数组选择数组元素。

对于你的特殊情况,因为你事先知道 b 的形状,你可以这样解决:

import numba as nb
import numpy as np

@nb.njit
def test(N,P,B,D):
    b = np.empty(D.shape[1], dtype=B.dtype)

    for i in range(N):
        a = D[i,:]
        for j in range(a.shape[0]):
            b[j] = B[i, j]
        P[:, i] = b

另一种解决方案是在调用测试之前在 B 上应用 swapaxes 并反转索引 (B[i,a] -> B[a,i])。我不知道为什么会这样,但这是实现:

import numba as nb
import numpy as np

@nb.njit()
def test(N,P,B,D):
    for i in range(N):
        a = D[i,:]
        b =  B[a,i]
        P[:, i] = b
    
P = np.zeros((5,5))
B = np.arange(25).reshape((5,5))
D = (np.random.random((5,5))*5).astype(np.int32)
N = 5
test(N,P,np.swapaxes(B, 0, 1), D)

对了,在@chrisb给出的答案中,不是:b[j] = B[i, j]而是b[j] = B[i, a[j]].