为什么 SSE4.2 cmpstr 比常规代码慢?

Why is SSE4.2 cmpstr slower than regular code?

我正在尝试验证一个字符串,该字符串必须只包含 ASCII 可见字符、白色 space 和 \t.

但在大多数 CPU 上,ASCII table 查找似乎比带 _SIDD_CMP_RANGES 的 _mm_cmpestri 指令更快。 我已经在 i5-2410M、i7-3720QM、i7-5600U 和未知类型的 KVM 虚拟化 Xeon 上对其进行了测试,只有最后一个是矢量化版本更快。

我的测试代码在这里:

#include <stdio.h>
#include <string.h>
#include <inttypes.h>
#include <sys/time.h>
#include <sys/mman.h>
#include <immintrin.h>
#include <stdalign.h>
#include <stdlib.h>

#define MIN(a,b) (((a)<(b))?(a):(b))

#define ALIGNED16 alignas(16)

#define MEASURE(msg,stmt) { \
    struct timeval tv; \
    gettimeofday(&tv, NULL); \
    uint64_t us1 = tv.tv_sec * (uint64_t)1000000 + tv.tv_usec; \
    stmt; \
    gettimeofday(&tv, NULL); \
    uint64_t us2 = tv.tv_sec * (uint64_t)1000000 + tv.tv_usec; \
    printf("%-20s - %.4fms\n", msg, ((double)us2 - us1) / 1000); \
}

// Character table
#define VWSCHAR(c)  (vis_ws_chars[(unsigned char)(c)])   // Visible characters and white space
#define YES     1,
#define NO      0,
#define YES16   YES YES YES YES YES YES YES YES YES YES YES YES YES YES YES YES
#define NO16    NO NO NO NO NO NO NO NO NO NO NO NO NO NO NO NO
#define NO128   NO16 NO16 NO16 NO16 NO16 NO16 NO16 NO16

// Visible ASCII characters with space and tab
ALIGNED16 static const int vis_ws_chars[256] = {
// NUL SOH STX ETX EOT ENQ ACK BEL BS  HT  LF  VT  FF  CR  SO  SI
   NO  NO  NO  NO  NO  NO  NO  NO  NO  YES NO  NO  NO  NO  NO  NO
// DLE DC1 DC2 DC3 DC4 NAK SYN ETB CAN EM  SUB ESC FS  GS  RS  US
   NO16
// SP  !   "   #   $   %   &   '   (   )   *   +   ,   -   .   /
// 0   1   2   3   4   5   6   7   8   9   :   ;   <   =   >   ?
// @   A   B   C   D   E   F   G   H   I   J   K   L   M   N   O
// P   Q   R   S   T   U   V   W   X   Y   Z   [   \   ]   ^   _
// `   a   b   c   d   e   f   g   h   i   j   k   l   m   n   o
   YES16 YES16 YES16 YES16 YES16
// p   q   r   s   t   u   v   w   x   y   z   {   |   }   ~   DEL
   YES YES YES YES YES YES YES YES YES YES YES YES YES YES YES NO
// Non-ASCII characters
   NO128
};

size_t search_logic(const char* data, size_t len) {
    __m128i ht = _mm_set1_epi8('\t');
    //__m128i del = _mm_set1_epi8(0x7f);
    __m128i td = _mm_set1_epi8('~');
    __m128i sp_m1 = _mm_set1_epi8(' ' - 1);
    size_t i = 0;
    while (len - i >= 16) {
        __m128i c = _mm_loadu_si128((const __m128i *) (data + i));
        // (!((c < del) && (c >= sp)) && (c != ht)) == 0
        //if(!_mm_testc_si128(_mm_and_si128(_mm_cmpgt_epi8(c, sp_m1), _mm_cmplt_epi8(c, del)), _mm_xor_si128(c, ht)))
            //break;
        // !(c == del) && ((c == ht) || (c >= sp)) == 1
        //if(!_mm_test_all_ones(_mm_andnot_si128(_mm_cmpeq_epi8(c, del), _mm_or_si128(_mm_cmpeq_epi8(c, ht), _mm_cmpgt_epi8(c, sp_m1)))))
            //break;
        // (((c != ht) && (c >= sp)) && (c > td)) == 0
        if(!_mm_test_all_zeros(_mm_and_si128(_mm_xor_si128(c, ht), _mm_cmpgt_epi8(c, sp_m1)), _mm_cmpgt_epi8(c, td)))
            break;
        i += 16;
    }
    // Check last 15 bytes
    for (; i < len; ++i) {
        if (!VWSCHAR(data[i])) {
            break;
        }
    }
    return i;
}

size_t search_table(const char* data, size_t len)
{
    // Search non-matching character via table lookups
    size_t i = 0;
    while (len - i >= 16) {
        if (!VWSCHAR(data[i + 0])) break;
        if (!VWSCHAR(data[i + 1])) break;
        if (!VWSCHAR(data[i + 2])) break;
        if (!VWSCHAR(data[i + 3])) break;
        if (!VWSCHAR(data[i + 4])) break;
        if (!VWSCHAR(data[i + 5])) break;
        if (!VWSCHAR(data[i + 6])) break;
        if (!VWSCHAR(data[i + 7])) break;
        if (!VWSCHAR(data[i + 8])) break;
        if (!VWSCHAR(data[i + 9])) break;
        if (!VWSCHAR(data[i + 10])) break;
        if (!VWSCHAR(data[i + 11])) break;
        if (!VWSCHAR(data[i + 12])) break;
        if (!VWSCHAR(data[i + 13])) break;
        if (!VWSCHAR(data[i + 14])) break;
        if (!VWSCHAR(data[i + 15])) break;
        i += 16;
    }
    // Check last 15 bytes
    for (; i < len; ++i) {
        if (!VWSCHAR(data[i])) {
            break;
        }
    }
    return i;
}

size_t search_sse4cmpstr(const char* data, size_t len)
{
    static const char legal_ranges[16] = {
        '\t', '\t',
        ' ',  '~',
    };
    __m128i v1 = _mm_loadu_si128((const __m128i*)legal_ranges);
    size_t i = 0;
    while (len - i >= 16) {
        __m128i v2 = _mm_loadu_si128((const __m128i*)(data + i));
        unsigned consumed = _mm_cmpestri(v1, 4, v2, 16, _SIDD_LEAST_SIGNIFICANT|_SIDD_CMP_RANGES|_SIDD_UBYTE_OPS|_SIDD_NEGATIVE_POLARITY);
        i += consumed;
        if (consumed < 16) {
            return i;
        }
    }
    // Check last 15 bytes
    for (; i < len; ++i) {
        if (!VWSCHAR(data[i])) {
            break;
        }
    }
    return i;
}

size_t search_sse4cmpstr_implicit(const char* data, size_t len)
{
    static const char legal_ranges[16] = {
        '\t', '\t',
        ' ',  '~',
    };
    __m128i v1 = _mm_loadu_si128((const __m128i*)legal_ranges);
    size_t i = 0;
    while (len - i >= 16) {
        __m128i v2 = _mm_loadu_si128((const __m128i*)(data + i));
        unsigned consumed = _mm_cmpistri(v1, v2, _SIDD_LEAST_SIGNIFICANT|_SIDD_CMP_RANGES|_SIDD_UBYTE_OPS|_SIDD_NEGATIVE_POLARITY);
        i += consumed;
        if (consumed < 16) {
            return i;
        }
    }
    // Check last 15 bytes
    for (; i < len; ++i) {
        if (!VWSCHAR(data[i])) {
            break;
        }
    }
    return i;
}

int main()
{
    printf("Setting up 1GB of data...\n");
    size_t len = 1024 * 1024 * 1024 + 3;
    char* data = (char*)mmap(NULL, len, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS|MAP_POPULATE, -1, 0); // Aligned
    srand(0);
    for (size_t i = 0; i < len; ++i) {
        const char v = rand() % 96;
        data[i] = v == 95 ? '\t' : ' ' + v;
    }
    size_t end = len - 2;
    data[end] = '\n'; // Illegal character to be found

    MEASURE("table lookup", {
        size_t i = search_table(data, len);
        if (i != end) printf("INCORRECT RESULT: %zu instead of %zu", i, end);
    });
    MEASURE("cmpestr ranges", {
        size_t i = search_sse4cmpstr(data, len);
        if (i != end) printf("INCORRECT RESULT: %zu instead of %zu", i, end);
    });
    MEASURE("cmpistr ranges", {
        size_t i = search_sse4cmpstr_implicit(data, len);
        if (i != end) printf("INCORRECT RESULT: %zu instead of %zu", i, end);
    });
    MEASURE("logic ranges", {
        size_t i = search_logic(data, len);
        if (i != end) printf("INCORRECT RESULT: %zu instead of %zu", i, end);
    });
}

gcc -O3 -march=native -pedantic -Wall -Wextra main2.cpp 编译得到这些结果:

Setting up 1GB of data...
table lookup         - 476.4820ms
cmpestr ranges       - 519.3350ms
cmpistr ranges       - 497.5770ms
logic ranges         - 153.2650ms

我还检查了汇编输出,search_sse4cmpstr 使用 vpcmpestri 而 search_table 是非矢量化的。

我是不是用错了?或者为什么这条指令存在?

编辑: 正如评论中指出的那样,cmpistr(参数较少的隐式长度指令)比 cmpestr 稍快,有时比 table 查找快。

但是,SSE2 按位和整数运算似乎更快。

编辑2 Peter Cordes 找到了正确的答案。 我已经在一个新的答案中添加了修改后的程序,所以如果你对 cmpstr 感兴趣,请看这个。

请勿使用以上代码!

该代码对前一个向量具有 i 的不必要依赖性,瓶颈在 pcmpestri + L1d 加载使用延迟大约 12 + 5 个周期。 (https://agner.org/optimize/ and https://uops.info/) 所以是的,不幸的是你用错了。

如果你写的它类似于你的标量循环,做 i+=16 并且只是检查 pcmpestri 结果作为循环退出条件,你会在它的 吞吐量上遇到瓶颈 在您的 Sandybridge 系列 CPU 上每 4 个时钟有 1 个向量。 (特别是 SnB 和 IvB)。

或者,如果您的输入可以使用 pcmpistri,那就没那么糟糕了,在 Sandybridge 系列上每 3 个时钟可以达到 1 个。

起初我没有注意到这个问题,因为我没想到循环会这样写,而且 asm 循环中还有其他混乱。 :/ 我花了很多时间用 perf 进行分析,以确保它不是来自我的 Skylake CPU 上的微编码(8 uop)指令的前端瓶颈。查看现在存档的评论。

吞吐量瓶颈会让您以大约 4 字节/周期的速度运行,而不是 4 字节/周期。 另一种方式约为 1(每个输入字节 2 个负载,而 Intel 因为 SnB 每个时钟可以执行 2 个负载)。所以加速了 4 倍。或者 Nehalem 上的 8 倍,负载吞吐量为 1/时钟。

巧合的是,延迟瓶颈大约是每个输入字节 1 个周期,与 table 查找大致相同。


此外,不要使用 len - i < 16; gcc 实际上计算了在循环内花费额外的 uops。一旦知道 len>=16,就使用 i < len-15。 (无符号类型使这变得棘手,因为它们在零处换行;您希望它编译成一个 cmp/jcc 来跳过循环,然后是一个 do{}while asm 循环结构。所以初始的 len>=16确实与正常循环条件分开。)


关于 pcmpestri 的其他有趣事实:

  • How much faster are SSE4.2 string instructions than SSE2 for memcmp?(速度较慢,尤其是使用 AVX2)
  • SSE42 & STTNI - PcmpEstrM is twice slower than PcmpIstrM, is it true? 是的,显式长度版本比隐式长度版本慢。显然,基于额外 2 个长度输入的屏蔽比在现有输入中扫描 0 字节更慢并且成本更高。
  • 性能不依赖于立即数的值。有一次我认为它确实如此,但那是 i 取决于结果,所以改变直接导致缓存行分裂,使循环延迟更糟。用 i+=16 循环重新测试显示没有效果。
  • 如果与 REX.W 前缀一起使用(在 RAX 和 RDX 而不是 EAX 和 EDX 中获取输入),它对英特尔来说要慢得多(根据 https://uops.info/),但没有内在的所以你不必担心编译器会那样做。

Or why does this instruction exist at all?

这些指令是在 Nehalem 中引入的。如果它们 "caught on" 并被广泛使用,英特尔可能已经计划让它们更快。对于短字符串 strcmp。但是如果没有故障抑制(对于可能跨入新页面的未对齐加载),如果不检查有关指针的内容,它们就很难使用。如果您无论如何都要进行检查,您不妨使用高效的 pcmpeqb/pmovmskb ,它的微指令更少。并且也许可以使用 pminub/pcmpeqb/pmovmskb -> bsf 在任一字符串中找到第一个零。也许有一个 SSE4.2 的用例用于 strcmp 的初始启动,但一旦开始就不会那么多了。

世界上大多数人关心的是 UTF-8,而不是 8 位字符集。由于 UTF-16 不再是固定宽度的(多亏了 32 位 Unicode),即使是宽字符的东西也很难用这些加速。

使用范围功能基本上需要手动矢量化,这对于只处理 ASCII 的东西来说是很多工作。

正如您所发现的,对于简单的情况,您可以使用 pcmpgtb 和布尔逻辑更快。使用 AVX2,您可以一次处理 32 个字节而不是 16 个字节,但是 vpcmpistri 没有 AVX2 版本,只有 16 字节指令的 AVX1 VEX 编码。

正如 Peter Cordes 指出的那样,问题是由对 cmpstr 输出的不必要依赖引起的。 这可以通过简单地重构这个循环来解决:

while (len - i >= 16) {
    __m128i v2 = _mm_loadu_si128((const __m128i*)(data + i));
    unsigned consumed = _mm_cmpistri(v1, v2, _SIDD_LEAST_SIGNIFICANT|_SIDD_CMP_RANGES|_SIDD_UBYTE_OPS|_SIDD_NEGATIVE_POLARITY);
    i += consumed;
    if (consumed < 16) {
        return i;
    }
}

进入那个:

if (len >= 16)
while (i <= len - 16) {
    __m128i v2 = _mm_loadu_si128((const __m128i*)(data + i));
    unsigned consumed = _mm_cmpistri(v1, v2, _SIDD_LEAST_SIGNIFICANT|_SIDD_CMP_RANGES|_SIDD_UBYTE_OPS|_SIDD_NEGATIVE_POLARITY);
    if (consumed < 16) {
        return i + consumed;
    }
    i += 16;
}

我的 i5-2410M 使用 gcc -pedantic -Wall -Wextra -O3 -march=native sse42cmpstr.c 编译的结果现在看起来好多了:

Setting up 1GB of data...
table                - 484.5900ms
cmpestr              - 231.9770ms
cmpistr              - 121.3510ms
logic                - 142.3700ms

现在 cmpistr 明显比 cmpestr 和 table 搜索都快,甚至超过 在我测试过的大多数 CPU 上手工制作的 SSE2 逻辑比较。

完整的测试代码在这里:

#include <stdio.h>
#include <inttypes.h>
#include <sys/time.h>
#include <sys/mman.h>
#include <immintrin.h>
#include <stdalign.h>

#define ALIGNED16 __attribute__((aligned(16)))

#define MEASURE(msg,stmt) { \
    struct timeval tv; \
    gettimeofday(&tv, NULL); \
    uint64_t us1 = tv.tv_sec * (uint64_t)1000000 + tv.tv_usec; \
    stmt; \
    gettimeofday(&tv, NULL); \
    uint64_t us2 = tv.tv_sec * (uint64_t)1000000 + tv.tv_usec; \
    printf("%-20s - %.4fms\n", msg, ((double)us2 - us1) / 1000); \
}

// Character table
#define VWSCHAR(c)  (vis_ws_chars[(unsigned char)(c)])   // Visible characters and white space
#define YES     1,
#define NO      0,
#define YES16   YES YES YES YES YES YES YES YES YES YES YES YES YES YES YES YES
#define NO16    NO NO NO NO NO NO NO NO NO NO NO NO NO NO NO NO
#define NO128   NO16 NO16 NO16 NO16 NO16 NO16 NO16 NO16

// Visible ASCII characters with space and tab
ALIGNED16 static const int vis_ws_chars[256] = {
// NUL SOH STX ETX EOT ENQ ACK BEL BS  HT  LF  VT  FF  CR  SO  SI
   NO  NO  NO  NO  NO  NO  NO  NO  NO  YES NO  NO  NO  NO  NO  NO
// DLE DC1 DC2 DC3 DC4 NAK SYN ETB CAN EM  SUB ESC FS  GS  RS  US
   NO16
// SP  !   "   #   $   %   &   '   (   )   *   +   ,   -   .   /
// 0   1   2   3   4   5   6   7   8   9   :   ;   <   =   >   ?
// @   A   B   C   D   E   F   G   H   I   J   K   L   M   N   O
// P   Q   R   S   T   U   V   W   X   Y   Z   [   \   ]   ^   _
// `   a   b   c   d   e   f   g   h   i   j   k   l   m   n   o
   YES16 YES16 YES16 YES16 YES16
// p   q   r   s   t   u   v   w   x   y   z   {   |   }   ~   DEL
   YES YES YES YES YES YES YES YES YES YES YES YES YES YES YES NO
// Non-ASCII characters
   NO128
};

// Search using the ASCII table above
size_t search_table(const char* data, size_t len)
{
    // Search non-matching character via table lookups
    size_t i = 0;
    if(len >= 16) {
        while (i <= len - 16) {
            if (!VWSCHAR(data[i + 0])) break;
            if (!VWSCHAR(data[i + 1])) break;
            if (!VWSCHAR(data[i + 2])) break;
            if (!VWSCHAR(data[i + 3])) break;
            if (!VWSCHAR(data[i + 4])) break;
            if (!VWSCHAR(data[i + 5])) break;
            if (!VWSCHAR(data[i + 6])) break;
            if (!VWSCHAR(data[i + 7])) break;
            if (!VWSCHAR(data[i + 8])) break;
            if (!VWSCHAR(data[i + 9])) break;
            if (!VWSCHAR(data[i + 10])) break;
            if (!VWSCHAR(data[i + 11])) break;
            if (!VWSCHAR(data[i + 12])) break;
            if (!VWSCHAR(data[i + 13])) break;
            if (!VWSCHAR(data[i + 14])) break;
            if (!VWSCHAR(data[i + 15])) break;
            i += 16;
        }
    }
    // Check last bytes
    for (; i < len; ++i) {
        if (!VWSCHAR(data[i])) {
            break;
        }
    }
    return i;
}

// Search using SSE4.2 cmpestri (explicit length)
size_t search_sse4cmpestr(const char* data, size_t len)
{
    ALIGNED16 static const char legal_ranges[16] = {
        '\t', '\t',
        ' ',  '~',
    };
    __m128i v1 = _mm_loadu_si128((const __m128i*) legal_ranges);
    size_t i = 0;
    if(len >= 16) {
        while (i <= len - 16) {
            __m128i v2 = _mm_loadu_si128((const __m128i*) (data + i));
            unsigned consumed = _mm_cmpestri(v1, 4, v2, 16, _SIDD_LEAST_SIGNIFICANT|_SIDD_CMP_RANGES|_SIDD_UBYTE_OPS|_SIDD_NEGATIVE_POLARITY);
            if (consumed < 16) {
                return i + consumed;
            }
            i += 16;
        }
    }
    for (; i < len; ++i) {
        if (!VWSCHAR(data[i])) {
            return i;
        }
    }
    return i;
}

// Search using SSE4.2 cmpistri (implicit length)
size_t search_sse4cmpistr(const char* data, size_t len)
{
    ALIGNED16 static const char legal_ranges[16] = {
        '\t', '\t',
        ' ',  '~',
    };
    __m128i v1 = _mm_loadu_si128((const __m128i*) legal_ranges);
    size_t i = 0;
    if (len >= 16) {
        while (i <= len - 16) {
            __m128i v2 = _mm_loadu_si128((const __m128i*)(data + i));
            unsigned consumed = _mm_cmpistri(v1, v2, _SIDD_LEAST_SIGNIFICANT|_SIDD_CMP_RANGES|_SIDD_UBYTE_OPS|_SIDD_NEGATIVE_POLARITY);
            if (consumed < 16) {
                return i + consumed;
            }
            i += 16;
        }
    }
    for (; i < len; ++i) {
        if (!VWSCHAR(data[i])) {
            return i;
        }
    }
    return i;
}

// Search using SSE2 logic instructions
size_t search_logic(const char* data, size_t len) {
    __m128i ht = _mm_set1_epi8('\t');
    //__m128i del = _mm_set1_epi8(0x7f);
    __m128i td = _mm_set1_epi8('~');
    __m128i sp_m1 = _mm_set1_epi8(' ' - 1);
    size_t i = 0;
    if(len >= 16) {
        while (len - 16 >= i) {
            __m128i c = _mm_loadu_si128((const __m128i *) (data + i));
            // (((c != ht) && (c >= sp)) && (c > td)) == 0
            if(!_mm_test_all_zeros(_mm_and_si128(_mm_xor_si128(c, ht), _mm_cmpgt_epi8(c, sp_m1)), _mm_cmpgt_epi8(c, td)))
                break;
            i += 16;
        }
    }
    // Check last bytes
    for (; i < len; ++i) {
        if (!VWSCHAR(data[i])) {
            break;
        }
    }
    return i;
}

int main()
{
    printf("Setting up 1GB of data...\n");
    size_t len = 1024 * 1024 * 1024 + 3;
    char* data = (char*)mmap(NULL, len, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS|MAP_POPULATE, -1, 0); // Aligned
    for (size_t i = 0; i < len; ++i) {
        const char v = i % 96;
        data[i] = v == 95 ? '\t' : ' ' + v;
    }
    size_t end = len - 2;
    data[end] = '\n'; // Illegal character to be found

    MEASURE("table", {
        size_t i = search_table(data, len);
        if (i != end) printf("INCORRECT RESULT: %u instead of %u\n", i, end);
    });
    MEASURE("cmpestr", {
        size_t i = search_sse4cmpestr(data, len);
        if (i != end) printf("INCORRECT RESULT: %u instead of %u\n", i, end);
    });
    MEASURE("cmpistr", {
        size_t i = search_sse4cmpistr(data, len);
        if (i != end) printf("INCORRECT RESULT: %u instead of %u\n", i, end);
    });
    MEASURE("logic", {
        size_t i = search_logic(data, len);
        if (i != end) printf("INCORRECT RESULT: %u instead of %u\n", i, end);
    });
}