使用 AVX-512 收集/分散 16 位整数

Gather / Scatter 16-bit integers using AVX-512

我一直在尝试弄清楚我们应该如何使用 AVX512 中的分散指令来分散 16 位整数。我所拥有的是 8 x 16 位整数存储在 __m256i 的每个 32 位整数中。我会使用 _mm512_i32extscatter_epi32 的 256 位等价物,向下转换 _MM_DOWNCONV_EPI32_UINT16,但没有这样的指令,向下转换在 AVX512 上不起作用。

我的理解是这样...我们必须进行 32 位读取和写入,并且我们必须小心让两个相邻的 16 位写入相互破坏(如果索引中的相同索引列出两次然后我不需要担心哪个先发生)。所以我们必须使用冲突收集分散循环。在循环中,我们必须在 32 位整数地址上发生冲突,或者将 16 位索引左移 1 并用作等效 32 位数组的索引(将 16 位数组转换为 32 位数组的等效项)数组,然后将索引除以 2)。然后我们需要读取一个 32 位整数,并根据 16 位数组的原始索引是奇数还是偶数来更改高 16 位或低 16 位。

这就是我得到的:

  1. 判断索引是奇数还是偶数,并相应地设置2位掩码01或10,形成8个整数的16位掩码。

  2. 通过将低16位复制到高16位,将16位整数转换为32位整数

  3. 通过右移一位将索引变成16位整数数组变成32位索引数组的索引

  4. 使用带掩码的冲突循环

  5. 屏蔽收集 32 位整数

  6. 使用_mm256_mask_blend_epi16选择是否改变刚刚读取的32位整数的高16位或低16位(使用(1)中的掩码)。

  7. 蒙面-散回记忆

  8. 重复,直到未写入的32位整数地址没有冲突。

请问有更快(或更简单)的方法吗?是的,我知道,个人写入速度更快 - 但这是关于如何使用 AVX-512 来完成的。

代码如下:

void scatter(uint16_t *array, __m256i vindex, __m256i a)
    {
    __mmask16 odd = _mm256_test_epi16_mask(vindex, _mm256_set1_epi32(1));
    __mmask16 even = ~odd & 0x5555;
    __mmask16 odd_even = odd << 1 | even;

    __m256i data = _mm256_mask_blend_epi16(0x5555, _mm256_bslli_epi128(a, 2), a);

    __m256i word_locations = _mm256_srli_epi32(vindex, 1);
    __mmask8 unwritten = 0xFF;
    do
        {
        __m256i conflict = _mm256_maskz_conflict_epi32 (unwritten, word_locations);
        conflict = _mm256_and_si256(_mm256_set1_epi32(unwritten), conflict);
        __mmask8 mask = unwritten & _mm256_testn_epi32_mask(conflict, _mm256_set1_epi32(0xFFFF'FFFF));

        __m256i was = _mm256_mmask_i32gather_epi32(_mm256_setzero_si256(), mask, word_locations, array, 4);
        __m256i send = _mm256_mask_blend_epi16(odd_even, was, data);
        _mm256_mask_i32scatter_epi32(array, mask, word_locations, send, 4);

        unwritten ^= mask;
        }
    while (unwritten != 0);
    }

如果读取 from/write 到最后一个索引后的两个字节是安全的,这也应该有效:

void scatter2(uint16_t *array, __m256i vindex, __m256i a) {
  __mmask8 odd = _mm256_test_epi32_mask(vindex, _mm256_set1_epi32(1));

  int32_t* arr32 = (int32_t*)array;
  __m256i was_odd = _mm256_i32gather_epi32(arr32, vindex, 2);

  __m256i data_even = _mm256_mask_blend_epi16(0x5555, was_odd, a);
  _mm256_mask_i32scatter_epi32(array, ~odd, vindex, data_even, 2);
  __m256i was_even = _mm256_i32gather_epi32(arr32, vindex, 2);

  __m256i data_odd = _mm256_mask_blend_epi16(0x5555, was_even, a);
  _mm256_mask_i32scatter_epi32(array, odd, vindex, data_odd, 2);
}

如果您可以保证 vindex 中的索引正在增加(或者至少对于 vindex [=13= 中任何部分冲突的 {i, i+1} ] 出现在 i) 之后,您可能可以通过一次聚集+混合+分散来逃脱。此外,使用掩码收集可能会有好处(即,每次只收集您接下来要覆盖的元素)——我不确定这是否会对吞吐量产生影响。最后,_mm256_mask_blend_epi16 实际上可以用简单的 _mm256_blend_epi16.

代替