仅当元素非零时如何进行 AVX-512 整数递增

How to do AVX-512 integer increment only if element is non zero

当且仅当元素的值不为零时,我才必须向 AVX 寄存器的元素添加值。下面是我的代码,但似乎我不得不去解决很多额外的麻烦,并且应该有更好的方法来做到这一点。注释掉的循环是我想做的事情的普通 C++ 表达式。

#include <immintrin.h>
#include <array>
#include <vector>
#include <iostream>
#include <bitset>
#include <cmath>
#include <chrono>

uint32_t make_bit_mask(bool T0 = 0, bool T1 = 0, bool T2 = 0, bool T3 = 0,
    bool T4 = 0, bool T5 = 0, bool T6 = 0, bool T7 = 0,
    bool T8 = 0, bool T9 = 0, bool T10 = 0, bool T11 = 0,
    bool T12 = 0, bool T13 = 0, bool T14 = 0, bool T15 = 0,
    bool T16 = 0, bool T17 = 0, bool T18 = 0, bool T19 = 0,
    bool T20 = 0, bool T21 = 0, bool T22 = 0, bool T23 = 0,
    bool T24 = 0, bool T25 = 0, bool T26 = 0, bool T27 = 0,
    bool T28 = 0, bool T29 = 0, bool T30 = 0, bool T31 = 0)
{
    return  ((T0 << 0) | (T1 << 1) | (T2 << 2) | (T3 << 3) |
        (T4 << 4) | (T5 << 5) | (T6 << 6) | (T7 << 7) |
        (T8 << 8) | (T9 << 9) | (T10 << 10) | (T11 << 11) |
        (T12 << 12) | (T13 << 13) | (T14 << 14) | (T15 << 15) |
        (T16 << 16) | (T17 << 17) | (T18 << 18) | (T19 << 19) |
        (T20 << 20) | (T21 << 21) | (T22 << 22) | (T23 << 23) |
        (T24 << 24) | (T25 << 25) | (T26 << 26) | (T27 << 27) |
        (T28 << 28) | (T29 << 29) | (T30 << 30) | (T31 << 31));
}

int main()
{
    std::vector<uint16_t> testValues{0};
    testValues.resize(65'536);
    for (size_t i{ 0 }; i < testValues.size(); i += 4)
    {
        testValues[i] = static_cast<uint16_t>(i);
    }
    auto start{ std::chrono::high_resolution_clock::now() };
    auto oneRegister{ _mm512_set1_epi16(1) };
    for (size_t i{ 0 }; i < testValues.size(); i += 32)
    {
        uint32_t loadMask{ make_bit_mask(testValues[i],testValues[i + 1],testValues[i + 2],testValues[i + 3],
                                            testValues[i + 4],testValues[i + 5],testValues[i + 6],testValues[i + 7],
                                            testValues[i + 8],testValues[i + 9],testValues[i + 10],testValues[i + 11],
                                            testValues[i + 12],testValues[i + 13],testValues[i + 14],testValues[i + 15],
                                            testValues[i + 16],testValues[i + 17],testValues[i + 18],testValues[i + 19],
                                            testValues[i + 20],testValues[i + 21],testValues[i + 22],testValues[i + 23],
                                            testValues[i + 24],testValues[i + 25],testValues[i + 26],testValues[i + 27],
                                            testValues[i + 28],testValues[i + 29],testValues[i + 30],testValues[i + 31])
        };
        _mm512_storeu_epi16(&testValues[i], _mm512_mask_add_epi16(_mm512_loadu_epi16(&testValues[i]),
            (__mmask32)loadMask, _mm512_loadu_epi16(&testValues[i]), oneRegister ));
    }
    /*
    for (auto& iter : testValues)
    {
        if (iter)
            iter += 1;
    }
    */
    auto end{ std::chrono::high_resolution_clock::now() };
    std::cout << "Summation took: " << std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count() << std::endl;
    return 0;
}

是的,有一个更简单的方法。简单地说,C++ 17 或更高版本中有二进制文字。如果你能用那个?

尝试这样的事情:

int main()
{
    uint32_t value = 0b1010'0100'0001'1100'0100'0011'1010'0010;
    //do whatever with value...
}

所以你想递增向量中的每个非零元素? (您实际上并不是在对一个向量中的元素求和,只是在进行垂直加法)。

根据元素非零,听起来您的真正问题是将整数数组转换为掩码。 AVX-512 对此有说明,例如与 0 比较,或者更具体地说 vptestmw k1, zmm,zmm 根据每个非零元素创建掩码。 (v & v 就是 v,所以你可以传递相同的操作数两次来绕过 AND 运算。

  __m512i v = _mm512_loadu_si512(&testValues[i]);
  __mmask32 nonzeros = _mm512_test_epi16_mask(v,v);
  v = _mm512_mask_sub_epi16(v, nonzeros, v, _mm512_set1_epi16(-1));  // set1(-1) is cheaper than 1

或者对于其他值,

 v = _mm512_mask_add_epi16(v, nonzeros, v, _mm512_set1_epi16( increment ));

在函数中,gcc -O2 -march=skylake-avx512 compiles it like this:

foo(unsigned short const*):
        vmovdqu64       zmm0, ZMMWORD PTR [rdi]
        vpternlogd      zmm1, zmm1, zmm1, 0xFF       # set1(-1)
        vptestmw        k1, zmm0, zmm0
        vpsubw  zmm0{k1}, zmm0, zmm1
        ret

set1(-1) 将被编译器提升到循环之外。


有趣的事实:clang 会为您将 add(v, set1(1)) 转换为 sub(v, set1(-1)),但 GCC 错过了优化。


如果没有 AVX-512(仅 AVX2 或 SSE2),您可以通过比较创建 0-1 向量。不幸的是,在 AVX-512 之前我们只有 cmpeqcmpgt,没有 cmpne,所以我们需要反转 0 / -1

  __m256i v = _mm256_loadu_si256((const __m256i*)&testValues[i]);
  __m256i nonzeros = _mm256_cmpeq_epi16(v, _mm256_setzero_si256());
  nonzeros = _mm256_xor_si256(nonzeros, _mm256_set1_epi32(-1)); 
          // or add 1 to turn 0->1 and -1->0 to use add()
  v = _mm256_sub_epi16(v, nonzeros);

对于任意常量,您可以_mm256_andnot_si256(nonzeros, _mm256_set1_epi16( increment ))创建一个矢量,v=0 元素为 0,非零元素为 increment