Python string comparison doesn't short circuit?

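The usual advice is that strings must be compared in constant time when checking passwords or hashes, so you should avoid a == b. However, I ran the following script, and the results don't support the assumption that a == b short-circuits on the first differing character.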

from time import perf_counter_ns
import random

def timed_cmp(a, b):
    start = perf_counter_ns()
    a == b
    end = perf_counter_ns()
    return end - start

def n_timed_cmp(n, a, b):
    "average time for a==b done n times"
    ts = [timed_cmp(a, b) for _ in range(n)]
    return sum(ts) / len(ts)

def check_cmp_time():
    random.seed(123)
    # generate a random string of n characters
    n = 2 ** 8
    s = "".join([chr(random.randint(ord("a"), ord("z"))) for _ in range(n)])

    # generate a list of strings, which all differs from the original string
    # by one character, at a different position
    # only do that for the first 50 char, it's enough to get data
    diffs = [s[:i] + "A" + s[i+1:] for i in range(min(50, n))]

    timed = [(i, n_timed_cmp(10000, s, d)) for (i, d) in enumerate(diffs)]
    sorted_timed = sorted(timed, key=lambda t: t[1])

    # print the 10 fastest
    for x in sorted_timed[:10]:
        i, t = x
        print("{}\t{:3f}".format(i, t))

    print("---")
    i, t = timed[0]
    print("{}\t{:3f}".format(i, t))

    i, t = timed[1]
    print("{}\t{:3f}".format(i, t))

if __name__ == "__main__":
    check_cmp_time()

Here are the results of one run. Re-running the script gives slightly different numbers, but none of them are conclusive.

# ran with cpython 3.8.3

6   78.051700
1   78.203200
15  78.222700
14  78.384800
11  78.396300
12  78.441800
9   78.476900
13  78.519000
8   78.586200
3   78.631500
---
0   80.691100
1   78.203200

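I expected the fastest comparisons to be the ones where the first differing character is at the beginning of the string, but that's not what I get. Any idea what's going on?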
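A side note on the constant-time advice mentioned in the question: Python's standard library already provides a comparison designed for this case, hmac.compare_digest. Its documentation says it avoids content-based short-circuiting. The sketch below is minimal, and check_token is a hypothetical helper name. Two caveats apply: str arguments must be ASCII-only, and a length mismatch can still be observable through timing.

```python
import hmac


def check_token(supplied: str, expected: str) -> bool:
    # hmac.compare_digest is documented to avoid content-based short
    # circuiting, so the position of the first mismatch should not leak
    # through timing. str inputs must contain only ASCII characters.
    return hmac.compare_digest(supplied, expected)


print(check_token("s3cret", "s3cret"))  # True
print(check_token("s3cret", "s3creT"))  # False
```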

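Well, to understand why it doesn't short-circuit, you have to do some digging. The simple answer, of course, is that it doesn't short-circuit because the standard doesn't require it to. But you might think, "Why wouldn't implementations choose to short-circuit? Surely that must be faster!" Not quite.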

Let's look at CPython, for obvious reasons. Here is the code of the unicode_compare_eq function defined in unicodeobject.c:
static int
unicode_compare_eq(PyObject *str1, PyObject *str2)
{
    int kind;
    void *data1, *data2;
    Py_ssize_t len;
    int cmp;

    len = PyUnicode_GET_LENGTH(str1);
    if (PyUnicode_GET_LENGTH(str2) != len)
        return 0;
    kind = PyUnicode_KIND(str1);
    if (PyUnicode_KIND(str2) != kind)
        return 0;
    data1 = PyUnicode_DATA(str1);
    data2 = PyUnicode_DATA(str2);

    cmp = memcmp(data1, data2, len * kind);
    return (cmp == 0);
}

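(Note: this function is only called after CPython has determined that str1 and str2 are not the same object. If they are, the comparison simply returns True immediately.)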
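The order of these checks is easier to see in a pure-Python model of unicode_compare_eq. This is only a sketch, and model_unicode_eq and _kind are hypothetical names. It encodes with latin-1, utf-16-le and utf-32-le to reproduce the 1-, 2- and 4-byte-per-character buffers that CPython stores (PEP 393). That encoding is an assumption of the model, and it lets the final bytes comparison play the role of memcmp.

```python
def _kind(s: str) -> int:
    # Bytes per code point CPython would use for this string (PEP 393):
    # 1 if the max code point is <= U+00FF, 2 if <= U+FFFF, otherwise 4.
    m = max(map(ord, s), default=0)
    if m <= 0xFF:
        return 1
    if m <= 0xFFFF:
        return 2
    return 4


# Codecs that lay out each kind's raw buffer (little-endian assumed).
_CODECS = {1: "latin-1", 2: "utf-16-le", 4: "utf-32-le"}


def model_unicode_eq(a: str, b: str) -> bool:
    if a is b:                  # identity shortcut taken before the C function
        return True
    if len(a) != len(b):        # PyUnicode_GET_LENGTH check
        return False
    kind = _kind(a)
    if _kind(b) != kind:        # PyUnicode_KIND check
        return False
    codec = _CODECS[kind]
    # Stand-in for memcmp(data1, data2, len * kind) on the raw buffers;
    # surrogatepass keeps lone surrogates encodable.
    return a.encode(codec, "surrogatepass") == b.encode(codec, "surrogatepass")
```

CPython always stores a string in the narrowest kind that fits, so two equal strings necessarily have the same length and the same kind. That is why the early "return 0" branches are safe.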

Focus on this line in particular:

cmp = memcmp(data1, data2, len * kind);

Aha, we're at another crossroads: does memcmp short-circuit? The C standard imposes no such requirement, as the opengroup docs and Section 7.24.4.1 of the C Standard Draft show:

7.24.4.1 The memcmp function

Synopsis

#include <string.h>
int memcmp(const void *s1, const void *s2, size_t n);

Description

The memcmp function compares the first n characters of the object pointed to by s1 to the first n characters of the object pointed to by s2.

Returns

The memcmp function returns an integer greater than, equal to, or less than zero, accordingly as the object pointed to by s1 is greater than, equal to, or less than the object pointed to by s2.
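The quoted text fixes only the result, not the procedure. The naive reading is a byte-by-byte loop that stops at the first difference. The sketch below is one valid way to produce that result (memcmp_reference is a hypothetical name). It compares bytes as unsigned values, as the general rule at the start of 7.24.4 requires. An implementation that keeps scanning, or that compares several bytes at once, returns the same value and is equally conforming.

```python
def memcmp_reference(s1: bytes, s2: bytes, n: int) -> int:
    # Iterating over bytes yields ints in 0..255, i.e. unsigned char values.
    for x, y in zip(s1[:n], s2[:n]):
        if x != y:
            # Stop at the first differing byte and report its sign.
            return -1 if x < y else 1
    return 0


print(memcmp_reference(b"abc", b"abd", 3))  # -1
print(memcmp_reference(b"abc", b"abd", 2))  # 0
```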

Most C implementations (including glibc) choose not to short-circuit. But why? Are we missing something? Why wouldn't you short-circuit?

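Because the comparison they use may not be as naive as a byte-by-byte check. The standard does not require the objects to be compared byte by byte, and that is where the opportunity for optimization lies.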

What glibc does is compare elements of type unsigned long int rather than individual bytes represented as unsigned char. Check out the implementation.
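The word-at-a-time idea can be sketched in Python. This is not glibc's code: the real implementation is written in C and assembly, handles alignment, and often uses SIMD. WORD = 8 assumes a 64-bit unsigned long, and memcmp_wordwise is a hypothetical name. Each chunk is decoded as a big-endian integer, so comparing the integers gives the same ordering as comparing the bytes one by one.

```python
WORD = 8  # assumption: sizeof(unsigned long) on a 64-bit platform


def memcmp_wordwise(s1: bytes, s2: bytes, n: int) -> int:
    i = 0
    # Compare one machine word per step. Big-endian decoding makes the
    # first byte the most significant, preserving bytewise ordering.
    while i + WORD <= n:
        a = int.from_bytes(s1[i:i + WORD], "big")
        b = int.from_bytes(s2[i:i + WORD], "big")
        if a != b:
            return -1 if a < b else 1
        i += WORD
    # Tail shorter than a word: fall back to byte-by-byte comparison.
    for x, y in zip(s1[i:n], s2[i:n]):
        if x != y:
            return -1 if x < y else 1
    return 0


print(memcmp_wordwise(b"abcdefgh1", b"abcdefgh2", 9))  # -1
```

With little-endian decoding the equality test would still work, but the sign of the result could be wrong, because the first byte would become the least significant one.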

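There's a lot more going on behind the scenes. A full discussion is well beyond the scope of this question; after all, this isn't even tagged as a C question ;). Though I found this may be worth a look. Just know that the optimization is there, in a form quite different from the approach one might think of at first glance.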

Edit: fixed the wrong function link.

Edit: As @Konrad Rudolph pointed out, glibc memcmp does short-circuit. I was misinformed.
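Combining the word-at-a-time description with this edit gives a rough, hypothetical cost model: the position of the first difference only matters at word granularity. The model ignores alignment, SIMD and call overhead. WORD = 8 and words_read are assumptions made for illustration.

```python
WORD = 8  # assumption: bytes compared per step (64-bit unsigned long)


def words_read(pos, kind=1):
    """Words a short-circuiting, word-at-a-time compare reads before it
    reaches a difference at character index pos (kind = bytes per char)."""
    return (pos * kind) // WORD + 1


# In an ASCII (kind 1) string, characters 0-7 share the first word, so in
# this model they cost the same; the cost only steps up every WORD bytes.
print([words_read(p) for p in range(0, 24, 4)])  # [1, 1, 2, 2, 3, 3]
```

Under this model, in the question's 256-character test a difference at position 50 costs only 6 more word reads than one at position 0. That gap is easily lost in interpreter overhead.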

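There is a difference; you just can't see it on such small strings. Here is a small patch to your code that uses much longer strings and runs 11 checks. It places the A every n // 10 characters, so the positions are spread evenly from the start to the end of the original string, like this: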

A_______________________________________________________________
______A_________________________________________________________
____________A___________________________________________________
__________________A_____________________________________________
________________________A_______________________________________
______________________________A_________________________________
____________________________________A___________________________
__________________________________________A_____________________
________________________________________________A_______________
______________________________________________________A_________
____________________________________________________________A___
@@ -15,13 +15,13 @@ def n_timed_cmp(n, a, b):
 def check_cmp_time():
     random.seed(123)
     # generate a random string of n characters
-    n = 2 ** 8
+    n = 2 ** 16
     s = "".join([chr(random.randint(ord("a"), ord("z"))) for _ in range(n)])

     # generate a list of strings, which all differs from the original string
     # by one character, at a different position
     # only do that for the first 50 char, it's enough to get data
-    diffs = [s[:i] + "A" + s[i+1:] for i in range(min(50, n))]
+    diffs = [s[:i] + "A" + s[i+1:] for i in range(0, n, n // 10)]

     timed = [(i, n_timed_cmp(10000, s, d)) for (i, d) in enumerate(diffs)]
     sorted_timed = sorted(timed, key=lambda t: t[1])

You get:

0   122.621000
1   213.465700
2   380.214100
3   460.422000
5   694.278700
4   722.010000
7   894.630300
6   1020.722100
9   1149.473000
8   1341.754500
---
0   122.621000
1   213.465700
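The positions used by the patched range(0, n, n // 10) can be printed directly; for n = 64 this reproduces the illustration above. For n = 64 and n = 2 ** 16, range yields 11 positions (0 through 10 * (n // 10)). That is why the illustration has 11 rows, and why the output above has no line for index 10: it prints only the 10 fastest. The name layout is a hypothetical helper.

```python
def layout(n, steps=10):
    # One row per tested position: "A" marks the modified character.
    return ["_" * i + "A" + "_" * (n - i - 1) for i in range(0, n, n // steps)]


for row in layout(64):
    print(row)
```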

Note that even with your example of only 2**8 characters, the difference is already visible. Apply this patch:

@@ -21,7 +21,7 @@ def check_cmp_time():
     # generate a list of strings, which all differs from the original string
     # by one character, at a different position
     # only do that for the first 50 char, it's enough to get data
-    diffs = [s[:i] + "A" + s[i+1:] for i in range(min(50, n))]
+    diffs = [s[:i] + "A" + s[i+1:] for i in [0, n - 1]]
 
     timed = [(i, n_timed_cmp(10000, s, d)) for (i, d) in enumerate(diffs)]
     sorted_timed = sorted(timed, key=lambda t: t[1])

Keeping only the two extreme cases (first letter changed vs. last letter changed), you get:

$ python3 cmp.py
0   124.131800
1   135.566000

The numbers vary, but most of the time test 0 is a bit faster than test 1.
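The first-vs-last comparison can also be reproduced with timeit, which times many comparisons in one loop instead of wrapping each one in perf_counter_ns calls. This is a sketch. The names best_ns, first and last, the repeat counts, and the choice to take the minimum (which the timeit documentation recommends) are all mine, not part of the original scripts. The lambda call overhead is included in every measurement, so the gap may be small relative to the total.

```python
import random
import string
import timeit

random.seed(123)
n = 2 ** 8
s = "".join(random.choice(string.ascii_lowercase) for _ in range(n))
first = "A" + s[1:]   # differs from s at position 0
last = s[:-1] + "A"   # differs from s at position n - 1


def best_ns(a, b, number=100_000, repeat=7):
    """Best-of-`repeat` time per comparison, in nanoseconds."""
    # timeit.repeat returns the total seconds for `number` calls, once per
    # repeat; the minimum is the run least disturbed by other activity.
    total = min(timeit.repeat(lambda: a == b, number=number, repeat=repeat))
    return total / number * 1e9


print("first char differs: {:.1f} ns".format(best_ns(s, first)))
print("last char differs:  {:.1f} ns".format(best_ns(s, last)))
```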

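You can isolate more precisely which character was modified, but only as long as memcmp works character by character, that is, as long as it doesn't use integer comparisons. That is typically the case for the last characters when they are not aligned to a word boundary, or for very short strings such as an 8-character string, as I demonstrate here: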

from time import perf_counter_ns
from statistics import median
import random


def check_cmp_time():
    random.seed(123)
    # generate a random string of n characters
    n = 8
    s = "".join([chr(random.randint(ord("a"), ord("z"))) for _ in range(n)])

    # generate a list of strings, which all differs from the original string
    # by one character, at a different position
    # one variant per position; n is small, so cover every character
    diffs = [s[:i] + "A" + s[i + 1 :] for i in range(n)]

    values = {x: [] for x in range(n)}
    for _ in range(10_000_000):
        for i, diff in enumerate(diffs):
            start = perf_counter_ns()
            s == diff
            values[i].append(perf_counter_ns() - start)

    timed = [[k, median(v)] for k, v in values.items()]
    sorted_timed = sorted(timed, key=lambda t: t[1])

    # print the 10 fastest
    for x in sorted_timed[:10]:
        i, t = x
        print("{}\t{:3f}".format(i, t))

    print("---")
    i, t = timed[0]
    print("{}\t{:3f}".format(i, t))

    i, t = timed[1]
    print("{}\t{:3f}".format(i, t))


if __name__ == "__main__":
    check_cmp_time()

This gives me:

1   221.000000
2   222.000000
3   223.000000
4   223.000000
5   223.000000
6   223.000000
7   223.000000
0   241.000000

The differences are so small that Python and perf_counter_ns may no longer be the right tools here.
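You can check whether the timer is fine-grained enough by asking Python for the clock's reported resolution and by measuring the cost of the timer itself. This is a sketch, and the sample count is arbitrary. The back-to-back gap is a rough lower bound on the overhead included in every single-comparison timing in the scripts above.

```python
import time

info = time.get_clock_info("perf_counter")
print("implementation:", info.implementation)
# resolution is reported in seconds; convert to ns to compare with the
# ~20 ns gap between position 0 and the other positions above.
print("reported resolution: {:.1f} ns".format(info.resolution * 1e9))

# Measure the gap between two back-to-back perf_counter_ns() calls.
gaps = []
for _ in range(100_000):
    t0 = time.perf_counter_ns()
    t1 = time.perf_counter_ns()
    gaps.append(t1 - t0)
gaps.sort()
print("median back-to-back gap:", gaps[len(gaps) // 2], "ns")
```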