Python 中多键排序的效率

Efficiency of sorting by multiple keys in Python

我有一个字符串列表,我想通过 Python 3.6 中的两个自定义键函数对其进行排序。比较多排序方法(按较小的键排序,然后按主键排序)与多键方法(将键作为元组 (major_key, lesser_key)),我可以看到后者比主键慢 2 倍以上前者,这是一个惊喜,因为我认为它们是等价的。我想明白为什么会这样。

import random
from time import time

largest = 1000000
length = 10000000
start = time()
lst = [str(x) for x in random.choices(range(largest), k=length)]
t0 = time() - start

start = time()
tmp = sorted(lst, key=lambda x: x[::2])
l1 = sorted(tmp, key=lambda x: ''.join(sorted(x)))
t1 = time() - start

start = time()
l2 = sorted(lst, key=lambda x: (''.join(sorted(x)), x[::2]))
t2 = time() - start

print(f'prepare={t0} multisort={t1} multikey={t2} slowdown={t2/t1}')

assert l1 == l2

更新

and 促使我进行更多分析。以下是不同比较的计数(每个元素)和其他统计数据:

               length:     10,000       100,000      1,000,000     10,000,000
                      | msort  mkey | msort  mkey | msort  mkey | msort  mkey |
----------------------+-------------+-------------+-------------+-------------+
 < ''.join(sorted(x)) | 11.64 11.01 | 13.86 11.88 | 13.10 12.00 | 12.16 12.06 |
 < x[::2]             | 11.99  0.96 | 13.51  3.20 | 13.77  5.35 | 13.79  6.68 |
== ''.join(sorted(x)) |       11.99 |       15.29 |       18.60 |       21.26 |
== x[::2]             |        0.98 |        3.42 |        6.60 |        9.20 |
----------------------+-------------+-------------+-------------+-------------+
time, μs per element  |  0.84  1.03 |  0.95  1.42 |  1.31  2.35 |  1.39  3.15 |
----------------------+-------------+-------------+-------------+-------------+
tracemalloc peak MiB  |  0.82  1.67 |  8.26 17.71 |  82.7 178.0 |   826  1780 |
----------------------+-------------+-------------+-------------+-------------+
sys.getsizeof MiB     |  0.58  1.80 |  5.72 17.82 |  57.6 179.5 |   581  1808 |
                      |  0.60       |  6.00       |  60.4       |   608       |
----------------------+-------------+-------------+-------------+-------------+
比较次数

每个元素的比较次数,即我计算了所有次比较并除以列表大小。因此,例如对于具有一百万个元素的列表,multisort 对每个 x[::2] 字符串进行了 13.77 < 次比较,然后对每个 ''.join(sorted(x)) 字符串进行了 13.10 < 次比较。而 multikey 有很多 == 的字符串比较和较少的 < 字符串比较。正如 Tim 指出的那样,unsafe_tuple_compare uses the slower PyObject_RichCompareBool(..., Py_EQ) before it uses the faster unsafe_latin_compare(在 ''.join(sorted(x)) 字符串上)或另一个较慢的 PyObject_RichCompareBool(..., Py_LT)(在 x[::2] 字符串上)。

需要注意的关键是 multisort 每个元素的比较次数大致保持不变,它只使用更快的 unsafe_latin_comparemultikey 除了 ''.join(sorted(x)) 上的 < 比较之外的比较数量增长迅速,包括额外的较慢的相等比较。

这与您的数据有关,因为 x 是 0 到 1,000,000 之间的整数的字符串表示形式。这会导致越来越多的重复项,对于 ''.join(sorted(x))(也有来自排序的重复项,例如 314159 和 911345 都变成“113459”)和 x[::2](也有来自切片的重复项,例如,123456 和 183957 都变成“135”)。

对于 multisort 这很好,因为重复意味着更少的工作 - 相等的东西不需要 “进一步排序”(我认为相等的条纹这对 Timsort 的疾驰来说是件好事。

但是 multikey 受到重复的困扰,因为 ==''.join(sorted(x)) 的比较更频繁地导致“它们相等”,从而导致对 x[::2] 字符串。这些 x[::2] 通常结果是 不相等 ,请注意 < x[::2] 的数字与 == x[::2] 比较相比如何相对较大(例如,一千万元素,有 9.20 == 次比较和 6.68 < 次比较,所以 73% 的时间它们是不相等的)。这意味着 元组 也常常不相等,因此它们确实需要 “进一步排序”(我认为这对奔跑不利)。这种进一步的排序将再次比较 整个元组 ,也意味着 甚至更多 == ''.join(sorted(x)) 字符串的比较,即使他们是平等的!这就是为什么在 multikey 中,''.join(sorted(x)) 字符串的 < 比较仍然相当 stable(从 10,000 个元素的每个字符串 11.01,到 1000 万个元素的 12.06)而他们的 == 比较增长如此之多(从 11.99 到 21.26)。

时代与记忆

上面table中的运行时间也反映了multisort的小增长和multikey的大增长。有一次 multisort 确实慢了很多是在从 100,000 到 1,000,000 个元素的步骤中,每个元素从 0.95 μs 到 1.31 μs。从我的 table 的 tracemallocsys.getsizeof 行可以看出,它的内存从 ~6 MiB 增加到 ~59 MiB,而 CPU 大约有 34 MiB缓存,因此缓存未命中可能是造成这种情况的原因。

对于 sys.getsizeof 值,我将所有键放入列表并添加列表的大小(因为 sorted 也存储它们)和所有 strings/tuples 的大小列表。对于multisort,我的table分别显示两个值,因为两个排序一个接一个,第一次排序的内存在第二次排序之前释放。

(注意:运行时与我原来的答案不同,因为我不得不使用不同的计算机对一千万个元素进行排序。)

比较计数代码

我将字符串包装在一个 S 对象中,该对象会为每次比较增加一个计数器。

import random
from collections import Counter

class S:
    def __init__(self, value, label):
        self.value = value
        self.label = label
    def __eq__(self, other):
        comparisons['==', self.label] += 1
        return self.value == other.value
    def __ne__(self, other):
        comparisons['!=', self.label] += 1
        return self.value != other.value
    def __lt__(self, other):
        comparisons['<', self.label] += 1
        return self.value < other.value
    def __le__(self, other):
        comparisons['<=', self.label] += 1
        return self.value <= other.value
    def __ge__(self, other):
        comparisons['>=', self.label] += 1
        return self.value >= other.value
    def __gt__(self, other):
        comparisons['>', self.label] += 1
        return self.value > other.value

def key1(x):
    return S(''.join(sorted(x)), "''.join(sorted(x))")

def key2(x):
    return S(x[::2], 'x[::2]')

def multisort(l):
    tmp = sorted(l, key=lambda x: key2(x))
    return sorted(tmp, key=lambda x: key1(x))

def multikey(l):
    return sorted(l, key=lambda x: (key1(x), key2(x)))

funcs = [
    multisort,
    multikey,
]

def test(length, rounds):
    print(f'{length = :,}')

    largest = 1000000

    for _ in range(3):
        lst = list(map(str, random.choices(range(largest + 1), k=length)))
        for func in funcs:
            global comparisons
            comparisons = Counter()
            func(lst)
            print(func.__name__)
            for comparison, frequency in comparisons.most_common():
                print(f'{frequency / length:5.2f}', *comparison)
            print()

test(10_000, 1)
test(100_000, 10)
test(1_000_000, 1)
test(10_000_000, 1)

原回答

是的,我经常使用/指出两种更简单的排序可能比一种更复杂的排序更快。以下是更多版本的基准测试:

length = 10,000 个元素处,multikey 花费的时间大约是 multisort:

的 1.16 倍
multisort           12.09 ms  12.21 ms  12.32 ms  (no previous result)
multikey            14.13 ms  14.14 ms  14.14 ms  (same as previous result)
multisort_tupled    15.40 ms  15.61 ms  15.70 ms  (same as previous result)
multisort_inplaced  11.46 ms  11.49 ms  11.49 ms  (same as previous result)

length = 100,000,用了大约1.43倍的时间:

length = 100,000
multisort           0.162 s  0.164 s  0.164 s  (no previous result)
multikey            0.231 s  0.233 s  0.237 s  (same as previous result)
multisort_tupled    0.222 s  0.225 s  0.227 s  (same as previous result)
multisort_inplaced  0.156 s  0.157 s  0.158 s  (same as previous result)

length = 1,000,000,花费了大约 2.15 倍的时间:

multisort           1.894 s  1.897 s  1.898 s  (no previous result)
multikey            4.026 s  4.060 s  4.163 s  (same as previous result)
multisort_tupled    2.734 s  2.765 s  2.771 s  (same as previous result)
multisort_inplaced  1.840 s  1.841 s  1.863 s  (same as previous result)

我看到的原因:

  • 构建元组需要额外的时间,比较事物的元组比只比较那些事物要慢。请参阅 multisort_tupled 的时代,在其中的两种类型中,我将每个真实键包装在一个单值元组中。这让它变慢了。
  • 对于大数据,cpu/memory 缓存发挥更重要的作用。每个元组有 two 东西的额外元组是内存中只有其中一个东西的对象数量的三倍。导致更多缓存未命中。如果你真的很大,甚至可以(更多)交换文件使用。
  • 元组已注册到引用循环垃圾收集器,。我们这里不产生引用循环,所以我们可以在排序过程中禁用它。加快一点。 编辑:我仍然认为它至少应该稍微快一点,但是随着更多的测试,我在这方面遇到了冲突,所以把它去掉了。

顺便说一句,请注意我的 multisort_inplaced 还是快了一点。如果您进行多种排序,那么使用 sorted 进行所有排序是没有意义的。在第一个之后,只需使用 inplace sort。没有理由创建更多新列表,这需要 time/memory 用于列表,并且需要时间来更新所有列表元素的引用计数。

基准代码(Try it online!):

import random
from time import perf_counter as timer

def multisort(l):
    tmp = sorted(l, key=lambda x: x[::2])
    return sorted(tmp, key=lambda x: ''.join(sorted(x)))

def multikey(l):
    return sorted(l, key=lambda x: (''.join(sorted(x)), x[::2]))

def multisort_tupled(l):
    tmp = sorted(l, key=lambda x: x[::2])
    return sorted(tmp, key=lambda x: (''.join(sorted(x)),))

def multisort_inplaced(l):
    tmp = sorted(l, key=lambda x: x[::2])
    tmp.sort(key=lambda x: ''.join(sorted(x)))
    return tmp

funcs = [
    multisort,
    multikey,
    multisort_tupled,
    multisort_inplaced,
]

def test(length, rounds):
    print(f'{length = :,}')

    largest = 1000000
    lst = list(map(str, random.choices(range(largest + 1), k=length)))

    prev = None
    for func in funcs:
        times = []
        for _ in range(5):
            time = 0
            for _ in range(rounds):
                copy = lst.copy()
                start = timer()
                result = func(copy)
                time += timer() - start
            times.append(time / rounds)
        print('%-19s' % func.__name__,
              *('%5.2f ms ' % (t * 1e3) for t in sorted(times)[:3]),
              # *('%5.3f s ' % t for t in sorted(times)[:3]),
              '(%s previous result)' % ('no' if prev is None else 'same as' if result == prev else 'differs from'))
        prev = result

test(10_000, 10)
# test(100_000, 10)
# test(1_000_000, 1)

这是计时的第三种方式:

start = time()
l3 = sorted(lst, key=lambda x: (''.join(sorted(x)) + "/" + x[::2]))
t3 = time() - start

并将最后一行扩展为

assert l1 == l2 == l3

这使用单个字符串作为键,但将您视为“主要”和“次要”键的两个字符串键组合在一起。注意:

>>> chr(ord("0") - 1)
'/'

这就是为什么可以组合这两个键的原因 - 它们由一个 ASCII 字符分隔,该字符比较“小于”任何 ASCII 数字(当然,这完全特定于您使用的键类型)。

对我来说,这通常 multisort(),使用您发布的精确程序。

prepare=3.628943920135498 multisort=15.646344423294067 multikey=34.255955934524536 slowdown=2.1893903782103075 onekey=15.11461067199707

我相信在现代 CPython 发行版的末尾简要解释了“为什么”的主要原因 Objects/listsort.txt:

As noted above, even the simplest Python comparison triggers a large pile of C-level pointer dereferences, conditionals, and function calls. This can be partially mitigated by pre-scanning the data to determine whether the data is homogeneous with respect to type. If so, it is sometimes possible to substitute faster type-specific comparisons for the slower, generic PyObject_RichCompareBool.

当有一个字符串用作键时,此预排序扫描推断出列表中的所有键实际上都是字符串,因此计算出 哪个[=56] 的所有运行时开销=] 可以跳过要调用的比较函数:排序总是可以调用特定于字符串的比较函数,而不是通用的(而且明显更昂贵)PyObject_RichCompareBool.

multisort() 也受益于该优化。

但是 multikey() 并没有,很多。预排序扫描看到所有的键都是元组,但是元组比较函数本身不能假设元组元素的类型:每次调用它时都必须求助于 PyObject_RichCompareBool 。 (注意:正如评论中提到的那样,它并不是那么简单:一些优化仍然是利用键都是元组来完成的,但它并不总是有回报,并且在最好的效果较差 - 请参阅下一节以获得更清晰的证据。)

重点

测试用例中发生了很多事情,这导致需要付出更大的努力来解释越来越小的区别。

因此,为了查看类型同质性优化的效果,让我们将事情简化很多:根本没有 key 函数。像这样:

from random import random, seed
from time import time

length = 10000000
seed(1234567891)
xs = [random() for _ in range(length)]

ys = xs[:]
start = time()
ys.sort()
e1 = time() - start

ys = [(x,) for x in xs]
start = time()
ys.sort()
e2 = time() - start

ys = [[x] for x in xs]
start = time()
ys.sort()
e3 = time() - start
print(e1, e2, e3)

这是我盒子上的典型输出:

3.1991195678710938 12.756590843200684 26.31903386116028

所以目前为止直接对浮点数进行排序是最快的。将浮点数固定在 1 元组中已经非常有害,但优化仍然提供了非常显着的好处:将浮点数固定在单例列表中再次花费两倍的时间。在最后一种情况下(并且仅在最后一种情况下),总是调用 PyObject_RichCompareBool