使用 Numba 提取 numpy 数组中的特定行

Extracting specific rows in numpy array using Numba

我有以下数组:

import numpy as np
from numba import njit


test_array = np.random.rand(4, 10)

我创建了一个 "jitted" 函数来对数组进行切片并在之后进行一些操作:

@njit(fastmath = True)
def test_function(array):

   test_array_sliced = test_array[[0,1,3]]

   return test_array_sliced

但是,Numba 抛出以下错误:

In definition 11:
    TypeError: unsupported array index type list(int64) in [list(int64)]
    raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.

解决方法

我尝试使用 np.delete 删除我不需要的行,但由于我必须指定 axis Numba 抛出以下错误:

@njit(fastmath = True)
def test_function(array):

   test_array_sliced = np.delete(test_array, obj = 2, axis = 0)

   return test_array_sliced

In definition 1:
    TypeError: np_delete() got an unexpected keyword argument 'axis'
    raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/templates.py:475
This error is usually caused by passing an argument of a type that is unsupported by the named function.

关于如何在 Numba 下提取特定行的任何想法?

Numba 不支持 numpy 花式索引。 我不是 100% 确定你的真实用例是什么样的,但一个简单的方法是这样的:

import numpy as np
import numba as nb

@nb.njit
def test_func(x):
    idx = (0, 1, 3)
    res = np.empty((len(idx), x.shape[1]), dtype=x.dtype)
    for i, ix in enumerate(idx):
        res[i] = x[ix]

    return res

test_array = np.random.rand(4, 10)
print(test_array)
print()
print(test_func(test_array))

编辑: @kwinkunks 是正确的,我原来的回答做了一个错误的笼统声明,即不支持花式索引。它出现在有限的一组案例中,包括这一个。

我认为如果您使用数组而不是列表进行索引,它会起作用(似乎建议如此in the docs):

test_array_sliced = array[np.array([0,1,3])]

(我将您要切片的数组更改为 array,这是您传递给函数的内容。也许是故意的,但要小心全局变量!)