使用 Numba 提取 numpy 数组中的特定行
Extracting specific rows in numpy array using Numba
我有以下数组:
import numpy as np
from numba import njit
test_array = np.random.rand(4, 10)
我创建了一个 "jitted" 函数来对数组进行切片并在之后进行一些操作:
@njit(fastmath = True)
def test_function(array):
test_array_sliced = test_array[[0,1,3]]
return test_array_sliced
但是,Numba 抛出以下错误:
In definition 11:
TypeError: unsupported array index type list(int64) in [list(int64)]
raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
解决方法
我尝试使用 np.delete
删除我不需要的行,但由于我必须指定 axis
Numba 抛出以下错误:
@njit(fastmath = True)
def test_function(array):
test_array_sliced = np.delete(test_array, obj = 2, axis = 0)
return test_array_sliced
In definition 1:
TypeError: np_delete() got an unexpected keyword argument 'axis'
raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/templates.py:475
This error is usually caused by passing an argument of a type that is unsupported by the named function.
关于如何在 Numba 下提取特定行的任何想法?
Numba 不支持 numpy 花式索引。 我不是 100% 确定你的真实用例是什么样的,但一个简单的方法是这样的:
import numpy as np
import numba as nb
@nb.njit
def test_func(x):
idx = (0, 1, 3)
res = np.empty((len(idx), x.shape[1]), dtype=x.dtype)
for i, ix in enumerate(idx):
res[i] = x[ix]
return res
test_array = np.random.rand(4, 10)
print(test_array)
print()
print(test_func(test_array))
编辑: @kwinkunks 是正确的,我原来的回答做了一个错误的笼统声明,即不支持花式索引。它出现在有限的一组案例中,包括这一个。
我认为如果您使用数组而不是列表进行索引,它会起作用(似乎建议如此in the docs):
test_array_sliced = array[np.array([0,1,3])]
(我将您要切片的数组更改为 array
,这是您传递给函数的内容。也许是故意的,但要小心全局变量!)
我有以下数组:
import numpy as np
from numba import njit
test_array = np.random.rand(4, 10)
我创建了一个 "jitted" 函数来对数组进行切片并在之后进行一些操作:
@njit(fastmath = True)
def test_function(array):
test_array_sliced = test_array[[0,1,3]]
return test_array_sliced
但是,Numba 抛出以下错误:
In definition 11:
TypeError: unsupported array index type list(int64) in [list(int64)]
raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
解决方法
我尝试使用 np.delete
删除我不需要的行,但由于我必须指定 axis
Numba 抛出以下错误:
@njit(fastmath = True)
def test_function(array):
test_array_sliced = np.delete(test_array, obj = 2, axis = 0)
return test_array_sliced
In definition 1:
TypeError: np_delete() got an unexpected keyword argument 'axis'
raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/templates.py:475
This error is usually caused by passing an argument of a type that is unsupported by the named function.
关于如何在 Numba 下提取特定行的任何想法?
Numba 不支持 numpy 花式索引。 我不是 100% 确定你的真实用例是什么样的,但一个简单的方法是这样的:
import numpy as np
import numba as nb
@nb.njit
def test_func(x):
idx = (0, 1, 3)
res = np.empty((len(idx), x.shape[1]), dtype=x.dtype)
for i, ix in enumerate(idx):
res[i] = x[ix]
return res
test_array = np.random.rand(4, 10)
print(test_array)
print()
print(test_func(test_array))
编辑: @kwinkunks 是正确的,我原来的回答做了一个错误的笼统声明,即不支持花式索引。它出现在有限的一组案例中,包括这一个。
我认为如果您使用数组而不是列表进行索引,它会起作用(似乎建议如此in the docs):
test_array_sliced = array[np.array([0,1,3])]
(我将您要切片的数组更改为 array
,这是您传递给函数的内容。也许是故意的,但要小心全局变量!)