在 numba no-python 模式下对 numpy 数组进行排序

Sorting an numpy array in numba no-python mode

Numba documentation 建议编译以下代码

@njit()
def accuracy(x,y,z):
    x.argsort(axis=1)
    # compare accuracy, this code works without the above line  
    accuracy_y = int(((np.equal(y, x).mean())*100)%100)
    accuracy_z = int(((np.equal(z, x).mean())*100)%100)
    return accuracy_y,accuracy_z

它在 x.argsort() 上失败了,我还尝试了以下使用和不使用轴参数的方法

np.argsort(x)
np.sort(x)
x.sort()

但是我得到以下编译失败错误(或类似错误):

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function sort at 0x000001B3B2CD2EE0>) found for signature:
 
 >>> sort(array(int64, 2d, C))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload of function 'sort': File: numba\core\typing\npydecl.py: Line 665.
    With argument(s): '(array(int64, 2d, C))':
   No match.

During: resolving callee type: Function(<function sort at 0x000001B3B2CD2EE0>)



File "accuracy.py", line 148:
def accuracy(x,lm,sfn):
    <source elided>
    # accuracy
    np.sort(x)
    ^

我在这里错过了什么?

感谢您的评论! 以下函数对二维数组进行排序!

from numba import njit
import numpy  as np

@njit()
def sort_2d_array(x):
    n,m=np.shape(x)
    for row in range(n):
        x[row]=np.sort(x[row])
    return x

arr=np.array([[3,2,1],[6,5,4],[9,8,7]])
y=sort_2d_array(arr)
print(y)

如果适合您的用例,您也可以考虑使用 guvectorize。这带来了能够指定要排序的轴的好处。可以通过在不同的轴上重复调用来完成超过 1 个维度的排序。

@guvectorize("(n)->(n)")
def sort_array(x, out):
    out[:] = np.sort(x)

使用略有不同的示例数组,其中的列也是无序的。

arr = np.array([
    [6,5,4],
    [3,2,1],
    [9,8,7]],
)

sort_array(arr, out, axis=0)
sort_array(out, out, axis=1)

显示:

array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

制作一个简单的包装器将允许一次对任意数量的维度进行排序。我认为 Numba 的重载不支持 guvectorize,否则您甚至可以使用它使 np.sort 在您的 jitted 函数中工作而无需更改任何内容。

https://numba.pydata.org/numba-doc/dev/extending/overloading-guide.html

对比 Numpy 测试输出:

for _ in range(20):
    
    arr = np.random.randint(0, 99, (9,9))

    # numba sorting
    out_nb = np.empty_like(arr)
    sort_array(arr, out_nb, axis=0)
    sort_array(out_nb, out_nb, axis=1)

    # numpy sorting
    out_np = np.sort(arr, axis=0)
    out_np = np.sort(out_np, axis=1)

    np.testing.assert_array_equal(out_nb, out_np)