NumbaNotImplementedError: only one advanced index supported -> how to rewrite a 3D [x,y,x] -> 2D array replacement of values numba can handle

NumbaNotImplementedError: only one advanced index supported -> how to rewrite a 3D [x,y,x] -> 2D array replacement of values numba can handle

我正在使用 parallel=True 将一些奇怪的代码转换为与 Numba 兼容。它有一个有问题的数组分配,我不太清楚如何以 numba 可以处理的方式重写。我试图解码错误的含义,但我迷路了。唯一清楚的是它不喜欢这条线:Averaging_price_3D[leg, :, expired_loc] = last_non_expired_values.T 错误很长,包含在这里以供参考:

TypingError: No implementation of function Function(<built-in function setitem>) found for signature:

setitem(array(float64, 3d, C), Tuple(int64, slice<a:b>, array(int64, 1d, C)), array(float64, 2d, F))

There are 16 candidate implementations:
  - Of which 14 did not match due to:
  Overload of function 'setitem': File: <numerous>: Line N/A.
    With argument(s): '(array(float64, 3d, C), Tuple(int64, slice<a:b>, array(int64, 1d, C)), array(float64, 2d, F))':
   No match.
  - Of which 2 did not match due to:
  Overload in function 'SetItemBuffer.generic': File: numba\core\typing\arraydecl.py: Line 176.
    With argument(s): '(array(float64, 3d, C), Tuple(int64, slice<a:b>, array(int64, 1d, C)), array(float64, 2d, F))':
   Rejected as the implementation raised a specific error:
     NumbaNotImplementedError: only one advanced index supported

这里是重现错误的一小段代码:

import numpy as np
import numba as nb

@nb.jit(nopython=True, parallel=True, nogil=True)
def main(Averaging_price_3D, expired_loc, last_non_expired_values):

    for leg in range(Averaging_price_3D.shape[0]):
        # line below causes the numba error:
        Averaging_price_3D[leg, :, expired_loc] = last_non_expired_values.T
    return Averaging_price_3D

if __name__ == "__main__":
    Averaging_price_3D=np.random.rand(2,8192,11)*100 # shape (2,8192,11) 3D array float64
    expired_loc=np.arange(4,10).astype(np.int64) # shape (6,) 1D array int64
    last_non_expired_values = Averaging_price_3D[1,:,0:expired_loc.shape[0]].copy() # shape (8192,6) 2D array float64

    result = main(Averaging_price_3D, expired_loc, last_non_expired_values)

现在我对这个错误的最好解释是“numba 不知道如何使用数组索引和 2D 数组中的值来设置 3D 矩阵中的值。”但是我在网上搜索了很多,找不到另一种方法来完成同样的事情,而不会导致 numba 崩溃。

在其他像这样的情况下,我在索引之前使用 .reshape(-1) 来展平数组,但我在弄清楚在这种特定情况下如何做到这一点时遇到了问题(这很容易用另一个 3D 数组索引的 3D 数组,因为它们都会以相同的顺序展平)...感谢任何帮助!

很有趣,我查看了传递给 3D 数组的索引(因为错误提示“只支持一个高级索引”,所以我选择检查我的索引):

3Darray[int, :, 1Darray]

看到 numba 很挑剔,我试着重写了一点,所以一维数组没有被用作索引(显然,这是一个“高级索引”,所以使用 int指数)。阅读 numba 错误和解决方案,他们倾向于添加循环,所以我在这里尝试了。因此,我没有将一维数组作为索引传递,而是遍历了一维数组的元素:

import numpy as np
import numba as nb


@nb.jit(cache=True, nopython=True, parallel=True, nogil=True)
def main(Averaging_price_3D, expired_loc, last_non_expired_values):

    for leg in nb.prange(Averaging_price_3D.shape[0]):
        # change the indexing for numba to get rid of a 1D array index
        for i in nb.prange(expired_loc.shape[0]): 
            # now we assign values 3Darray[int,:,int] = 1Darray  
            Averaging_price_3D[leg, :, expired_loc[i]] = last_non_expired_values[:,i].T 
    return Averaging_price_3D

if __name__ == "__main__":
    Averaging_price_3D=np.random.rand(2,8192,11)*100 # shape (2,8192,11) 3D array float64
    expired_loc=np.arange(4,10).astype(np.int64) # shape (6,) 1D array int64
    last_non_expired_values = Averaging_price_3D[1,:,0:expired_loc.shape[0]] # shape (8192,6) 2D array float64

    result = main(Averaging_price_3D, expired_loc, last_non_expired_values)

现在完全没问题了。所以在我看来,如果你想使用 numba 访问 3D 数组中的元素,你应该使用 ints: 来完成。它似乎不喜欢 1D 数组索引,所以用循环替换它,它应该 运行 并行。