测试 AVX 寄存器是否包含一些相等的整数

Testing whether AVX register contains some equal integer numbers

考虑一个包含四个 64 位整数的 256 位寄存器。 在 AVX/AVX2 中是否可以有效地测试其中一些整数是否相等?

例如:

a) {43, 17, 25, 8}:结果必须是false,因为4个数字中没有2个是相等的。

b) {47, 17, 23, 17}:结果必须是'true',因为数字17在AVX向量寄存器中出现了2次。

如果可能的话,我想在 C++ 中执行此操作,但如果需要,我可以转到汇编。

使用 AVX512 (AVX512VL + AVX512CD),您将使用专为此目的设计的 VPCONFLICTQ


对于AVX2:

通过减少冗余比较减少了一些操作:

int test1(__m256i x)
{
    __m256i x0 = _mm256_permute4x64_epi64(x, 0x4B);
    // 1 0 2 3
    // 3 2 1 0
    __m256i e0 = _mm256_cmpeq_epi64(x0, x);
    __m256i x1 = _mm256_shuffle_epi32(x, 0x4E);
    // 2 3 0 1
    // 3 2 1 0
    __m256i e1 = _mm256_cmpeq_epi64(x1, x);
    __m256i t = _mm256_or_si256(e0, e1);
    return !_mm256_testz_si256(t, _mm256_set1_epi32(-1));
}

之前:

一个简单的 "compare everything with everything" 方法可以用于一些随机播放,像这样(未测试):

int hasDupe(__m256i x)
{
    __m256i x1 = _mm256_shuffle_epi32(x, 0x4E);
    __m256i x2 = _mm256_permute4x64_epi64(x, 0x4E);
    __m256i x3 = _mm256_shuffle_epi32(x2, 0x4E);
    // 2 3 0 1
    // 3 2 1 0
    __m256i e0 = _mm256_cmpeq_epi64(x1, x);
    // 1 0 3 2
    // 3 2 1 0
    __m256i e1 = _mm256_cmpeq_epi64(x2, x);
    // 0 1 2 3
    // 3 2 1 0
    __m256i e2 = _mm256_cmpeq_epi64(x3, x);
    __m256i t0 = _mm256_or_si256(_mm256_or_si256(e0, e1), e2);
    return !_mm256_testz_si256(t0, _mm256_set1_epi32(-1));
}

GCC 7 将其编译为合理的代码,但 Clang 确实做了一些奇怪的事情。似乎认为 vpor 没有 256 位版本(它完全有)。在这种情况下,将 OR 更改为加法会做大致相同的事情(将几个 -1 加在一起不会为零)并且不会导致 Clang 出现问题(也未测试):

int hasDupe(__m256i x)
{
    __m256i x1 = _mm256_shuffle_epi32(x, 0x4E);
    __m256i x2 = _mm256_permute4x64_epi64(x, 0x4E);
    __m256i x3 = _mm256_shuffle_epi32(x2, 0x4E);
    // 2 3 0 1
    // 3 2 1 0
    __m256i e0 = _mm256_cmpeq_epi64(x1, x);
    // 1 0 3 2
    // 3 2 1 0
    __m256i e1 = _mm256_cmpeq_epi64(x2, x);
    // 0 1 2 3
    // 3 2 1 0
    __m256i e2 = _mm256_cmpeq_epi64(x3, x);
    // "OR" results, workaround for Clang being weird
    __m256i t0 = _mm256_add_epi64(_mm256_add_epi64(e0, e1), e2);
    return !_mm256_testz_si256(t0, _mm256_set1_epi32(-1));
}