在 numba no-python 模式下对 numpy 数组进行排序
Sorting an numpy array in numba no-python mode
Numba documentation 建议编译以下代码
@njit()
def accuracy(x,y,z):
x.argsort(axis=1)
# compare accuracy, this code works without the above line
accuracy_y = int(((np.equal(y, x).mean())*100)%100)
accuracy_z = int(((np.equal(z, x).mean())*100)%100)
return accuracy_y,accuracy_z
它在 x.argsort()
上失败了,我还尝试了以下使用和不使用轴参数的方法
np.argsort(x)
np.sort(x)
x.sort()
但是我得到以下编译失败错误(或类似错误):
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function sort at 0x000001B3B2CD2EE0>) found for signature:
>>> sort(array(int64, 2d, C))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload of function 'sort': File: numba\core\typing\npydecl.py: Line 665.
With argument(s): '(array(int64, 2d, C))':
No match.
During: resolving callee type: Function(<function sort at 0x000001B3B2CD2EE0>)
File "accuracy.py", line 148:
def accuracy(x,lm,sfn):
<source elided>
# accuracy
np.sort(x)
^
我在这里错过了什么?
感谢您的评论!
以下函数对二维数组进行排序!
from numba import njit
import numpy as np
@njit()
def sort_2d_array(x):
n,m=np.shape(x)
for row in range(n):
x[row]=np.sort(x[row])
return x
arr=np.array([[3,2,1],[6,5,4],[9,8,7]])
y=sort_2d_array(arr)
print(y)
如果适合您的用例,您也可以考虑使用 guvectorize
。这带来了能够指定要排序的轴的好处。可以通过在不同的轴上重复调用来完成超过 1 个维度的排序。
@guvectorize("(n)->(n)")
def sort_array(x, out):
out[:] = np.sort(x)
使用略有不同的示例数组,其中的列也是无序的。
arr = np.array([
[6,5,4],
[3,2,1],
[9,8,7]],
)
sort_array(arr, out, axis=0)
sort_array(out, out, axis=1)
显示:
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
制作一个简单的包装器将允许一次对任意数量的维度进行排序。我认为 Numba 的重载不支持 guvectorize
,否则您甚至可以使用它使 np.sort
在您的 jitted 函数中工作而无需更改任何内容。
https://numba.pydata.org/numba-doc/dev/extending/overloading-guide.html
对比 Numpy 测试输出:
for _ in range(20):
arr = np.random.randint(0, 99, (9,9))
# numba sorting
out_nb = np.empty_like(arr)
sort_array(arr, out_nb, axis=0)
sort_array(out_nb, out_nb, axis=1)
# numpy sorting
out_np = np.sort(arr, axis=0)
out_np = np.sort(out_np, axis=1)
np.testing.assert_array_equal(out_nb, out_np)
Numba documentation 建议编译以下代码
@njit()
def accuracy(x,y,z):
x.argsort(axis=1)
# compare accuracy, this code works without the above line
accuracy_y = int(((np.equal(y, x).mean())*100)%100)
accuracy_z = int(((np.equal(z, x).mean())*100)%100)
return accuracy_y,accuracy_z
它在 x.argsort()
上失败了,我还尝试了以下使用和不使用轴参数的方法
np.argsort(x)
np.sort(x)
x.sort()
但是我得到以下编译失败错误(或类似错误):
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function sort at 0x000001B3B2CD2EE0>) found for signature:
>>> sort(array(int64, 2d, C))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload of function 'sort': File: numba\core\typing\npydecl.py: Line 665.
With argument(s): '(array(int64, 2d, C))':
No match.
During: resolving callee type: Function(<function sort at 0x000001B3B2CD2EE0>)
File "accuracy.py", line 148:
def accuracy(x,lm,sfn):
<source elided>
# accuracy
np.sort(x)
^
我在这里错过了什么?
感谢您的评论! 以下函数对二维数组进行排序!
from numba import njit
import numpy as np
@njit()
def sort_2d_array(x):
n,m=np.shape(x)
for row in range(n):
x[row]=np.sort(x[row])
return x
arr=np.array([[3,2,1],[6,5,4],[9,8,7]])
y=sort_2d_array(arr)
print(y)
如果适合您的用例,您也可以考虑使用 guvectorize
。这带来了能够指定要排序的轴的好处。可以通过在不同的轴上重复调用来完成超过 1 个维度的排序。
@guvectorize("(n)->(n)")
def sort_array(x, out):
out[:] = np.sort(x)
使用略有不同的示例数组,其中的列也是无序的。
arr = np.array([
[6,5,4],
[3,2,1],
[9,8,7]],
)
sort_array(arr, out, axis=0)
sort_array(out, out, axis=1)
显示:
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
制作一个简单的包装器将允许一次对任意数量的维度进行排序。我认为 Numba 的重载不支持 guvectorize
,否则您甚至可以使用它使 np.sort
在您的 jitted 函数中工作而无需更改任何内容。
https://numba.pydata.org/numba-doc/dev/extending/overloading-guide.html
对比 Numpy 测试输出:
for _ in range(20):
arr = np.random.randint(0, 99, (9,9))
# numba sorting
out_nb = np.empty_like(arr)
sort_array(arr, out_nb, axis=0)
sort_array(out_nb, out_nb, axis=1)
# numpy sorting
out_np = np.sort(arr, axis=0)
out_np = np.sort(out_np, axis=1)
np.testing.assert_array_equal(out_nb, out_np)