select 数组中最小的 n 个元素的最快方法是什么?

What is the fastest way to select the smallest n elements from an array?

我使用 numba 编写 quick select algorithm 很开心,想分享结果。

考虑数组x

np.random.seed([3,1415])
x = np.random.permutation(np.arange(10))
x

array([9, 4, 5, 1, 7, 6, 8, 3, 2, 0])

拉取最小的 n 个元素的最快方法是什么。

我试过了
np.partition

np.partition(x, 5)[:5]

array([0, 1, 2, 3, 4])

pd.Series.nsmallest

pd.Series(x).nsmallest(5).values

array([0, 1, 2, 3, 4])

更新
@user2357112 在评论中指出我的函数是在原地操作。结果证明这就是我的性能提升的来源。所以最后,我们从 quickselectnumba 的粗略实现中获得了非常相似的性能。仍然没有什么可打喷嚏的,但不是我所希望的。


正如我在问题中所说的那样,我正在弄乱 numba 并想分享我的发现。

请注意,我导入了 njit 而不是 jit。这是一个装饰器,可以自动防止自己回退到原生 python 对象。这意味着当它进行加速时,它只会使用它实际上可以加速的东西。这反过来意味着当我弄清楚什么是允许的和什么是不允许的时,我的功能经常失败。

到目前为止,我认为用 numbas jitnjit 写东西是挑剔和困难的,但当你看到一个不错的表现时,有点值得回报。

这是我的快速而肮脏的 quickselect 函数

import numpy as np
from numba import njit
import pandas as pd
import numexpr as ne

@njit
def rselect(a, k):
    n = len(a)
    if n <= 1:
        return a
    elif k > n:
        return a
    else:
        p = np.random.randint(n)
        pivot = a[p]
        a[0], a[p] = a[p], a[0]
        i = j = 1
        while j < n:
            if a[j] < pivot:
                a[j], a[i] = a[i], a[j]
                i += 1
            j += 1
        a[i-1], a[0] = a[0], a[i-1]
        if i - 1 <= k <= i:
            return a[:k]
        elif k > i:
            return np.concatenate((a[:i], rselect(a[i:], k - i)))
        else:
            return rselect(a[:i-1], k)

您会注意到它 returns 与问题中的方法相同的元素。

rselect(x, 5)

array([2, 1, 0, 3, 4])

速度呢?

def nsmall_np(x, n):
    return np.partition(x, n)[:n]

def nsmall_pd(x, n):
    pd.Series(x).nsmallest().values

def nsmall_pir(x, n):
    return rselect(x.copy(), n)


from timeit import timeit


results = pd.DataFrame(
    index=pd.Index([100, 1000, 3000, 6000, 10000, 100000, 1000000], name='Size'),
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir'], name='Method')
)

for i in results.index:
    x = np.random.rand(i)
    n = i // 2
    for j in results.columns:
        stmt = '{}(x, n)'.format(j)
        setp = 'from __main__ import {}, x, n'.format(j)
        results.set_value(
            i, j, timeit(stmt, setp, number=1000)
        )

results

Method   nsmall_np  nsmall_pd  nsmall_pir
Size                                     
100       0.003873   0.336693    0.002941
1000      0.007683   1.170193    0.011460
3000      0.016083   0.309765    0.029628
6000      0.050026   0.346420    0.059591
10000     0.106036   0.435710    0.092076
100000    1.064301   2.073206    0.936986
1000000  11.864195  27.447762   12.755983

results.plot(title='Selection Speed', colormap='jet', figsize=(10, 6))

.png

一般来说,我不建议尝试打败 NumPy。很少有人可以竞争(对于长数组),找到更快的实现就更少了。即使速度更快,也可能不会快 2 倍。所以很少值得。

但是我最近尝试自己做这样的事情,所以我可以分享一些有趣的结果。

这不是我自己想出来的。我的方法基于 numbas (re-)implementation of np.median他们可能知道他们在做什么。

我最后得到的是:

import numba as nb
import numpy as np

@nb.njit
def _partition(A, low, high):
    """copied from numba source code"""
    mid = (low + high) >> 1
    if A[mid] < A[low]:
        A[low], A[mid] = A[mid], A[low]
    if A[high] < A[mid]:
        A[high], A[mid] = A[mid], A[high]
        if A[mid] < A[low]:
            A[low], A[mid] = A[mid], A[low]
    pivot = A[mid]

    A[high], A[mid] = A[mid], A[high]

    i = low
    for j in range(low, high):
        if A[j] <= pivot:
            A[i], A[j] = A[j], A[i]
            i += 1

    A[i], A[high] = A[high], A[i]
    return i

@nb.njit
def _select_lowest(arry, k, low, high):
    """copied from numba source code, slightly changed"""
    i = _partition(arry, low, high)
    while i != k:
        if i < k:
            low = i + 1
            i = _partition(arry, low, high)
        else:
            high = i - 1
            i = _partition(arry, low, high)
    return arry[:k]

@nb.njit
def _nlowest_inner(temp_arry, n, idx):
    """copied from numba source code, slightly changed"""
    low = 0
    high = n - 1
    return _select_lowest(temp_arry, idx, low, high)

@nb.njit
def nlowest(a, idx):
    """copied from numba source code, slightly changed"""
    temp_arry = a.flatten()  # does a copy! :)
    n = temp_arry.shape[0]
    return _nlowest_inner(temp_arry, n, idx)

我在计时之前加入了一些热身电话。预热所以编译时间不包括在计时中:

rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)

我有一台(慢很多)的电脑,我稍微改变了元素的数量和重复的次数。但结果似乎表明我(好吧,numba 开发人员确实)打败了 NumPy:

results = pd.DataFrame(
    index=pd.Index([100, 500, 1000, 5000, 10000, 50000, 100000, 500000], name='Size'),
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir', 'nlowest'], name='Method')
)

rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)

for i in results.index:
    x = np.random.rand(i)
    n = i // 2
    for j in results.columns:
        stmt = '{}(x, n)'.format(j)
        setp = 'from __main__ import {}, x, n'.format(j)
        results.set_value(i, j, timeit(stmt, setp, number=100))

print(results)

Method   nsmall_np nsmall_pd  nsmall_pir      nlowest
Size                                                 
100     0.00343059  0.561372  0.00190855  0.000935566
500     0.00428461   1.79398  0.00326862   0.00187225
1000    0.00560669   3.36844  0.00432595   0.00364284
5000     0.0132515  0.305471   0.0142569    0.0108995
10000    0.0255161  0.340215    0.024847    0.0248285
50000     0.105937  0.543337    0.150277     0.118294
100000      0.2452  0.835571    0.333697     0.248473
500000     1.75214   3.50201     2.20235      1.44085