如何从 AVX 寄存器中获取数据?

How to get data out of AVX registers?

使用 MSVC 2013 和 AVX 1,我在寄存器中有 8 个浮点数:

__m256 foo = mm256_fmadd_ps(a,b,c);

现在我想为所有 8 个花车调用 inline void print(float) {...}。看起来 Intel AVX intrisics 会使这变得相当复杂:

print(_castu32_f32(_mm256_extract_epi32(foo, 0)));
print(_castu32_f32(_mm256_extract_epi32(foo, 1)));
print(_castu32_f32(_mm256_extract_epi32(foo, 2)));
// ...

但是 MSVC 甚至没有这两个内在函数。当然,我可以将值写回内存并从那里加载,但我怀疑在汇编级别不需要溢出寄存器。

红利问:我当然愿意写

for(int i = 0; i !=8; ++i) 
    print(_castu32_f32(_mm256_extract_epi32(foo, i)))

但 MSVC 不理解许多内部函数 需要 循环展开。如何在 __m256 foo 中的 8x32 浮点数上编写循环?

假设您只有 AVX(即没有 AVX2),那么您可以这样做:

float extract_float(const __m128 v, const int i)
{
    float x;
    _MM_EXTRACT_FLOAT(x, v, i);
    return x;
}

void print(const __m128 v)
{
    print(extract_float(v, 0));
    print(extract_float(v, 1));
    print(extract_float(v, 2));
    print(extract_float(v, 3));
}

void print(const __m256 v)
{
    print(_mm256_extractf128_ps(v, 0));
    print(_mm256_extractf128_ps(v, 1));
}

不过我想我可能只使用联合:

union U256f {
    __m256 v;
    float a[8];
};

void print(const __m256 v)
{
    const U256f u = { v };

    for (int i = 0; i < 8; ++i)
        print(u.a[i]);
}

注意:_mm256_fmadd_ps 不是 AVX1 的一部分。 FMA3 有自己的功能位,并且仅在 Intel 上与 Haswell 一起引入。 AMD 引入了带打桩机的 FMA3(AVX1+FMA4+FMA3,无 AVX2)。


在asm层面,如果想把8个32bit的元素放到整数寄存器中,先入栈再做标量加载其实更快。 pextrd 是 SnB 系列和 Bulldozer 系列的 2-uop 指令。 (以及不支持 AVX 的 Nehalem 和 Silvermont)。

唯一 CPU vextractf128 + 2xmovd + 6xpextrd 并不可怕的是 AMD Jaguar。 (便宜 pextrd,并且只有一个加载端口。)(参见 Agner Fog's insn tables

宽对齐的存储可以转发到重叠的窄负载。 (当然,你可以使用movd来获取低位元素,这样你就有了load port和ALU port uops的混合)。


当然,您似乎是通过使用整数提取然后将其转换回浮点数来提取 floats。这看起来很糟糕。

你真正需要的是每个 float 在它自己的 xmm 寄存器的低元素中。 vextractf128 显然是开始的方式,将元素 4 带到新的 xmm reg 的底部。那么6x AVX shufps 就可以轻松得到各半的其他三个元素。 (或者 movshdupmovhlps 的编码更短:没有立即字节)。

7 个洗牌微指令与 1 个存储微指令和 7 个加载微指令相比值得考虑,但如果您无论如何都要为函数调用溢出向量,则不值得考虑。


ABI 注意事项:

您在 Windows,其中 xmm6-15 被调用保留(只有 low128;ymm6-15 的上半部分被调用破坏)。这是以 vextractf128.

开头的另一个原因

在 SysV ABI 中,所有 xmm / ymm / zmm 寄存器都被调用破坏,因此每个 print() 函数都需要一个 spill/reload。唯一明智的做法是存储到内存并使用原始向量调用 print(即打印低元素,因为它将忽略寄存器的其余部分)。然后 movss xmm0, [rsp+4] 并在第二个元素上调用 print

将所有 8 个浮点数很好地解压缩到 8 个向量 reg 中对你没有好处,因为无论如何在第一次函数调用之前它们都必须单独溢出!

(未完成的答案。无论如何发布以防万一它对任何人有帮助,或者万一我回到它。通常,如果您需要与无法矢量化的标量接口,那么这样做也不错只需将向量存储到本地数组,然后一次重新加载一个元素。)

有关 asm 详细信息,请参阅我的其他答案。这个答案是关于 C++ 方面的事情。


void foo(__m256 v) {
    alignas(32) float vecbuf[8];   // 32-byte aligned array allows aligned store
                                   // avoiding the risk of cache-line splits
    _mm256_store_ps(vecbuf, v);

    float v0 = _mm_cvtss_f32(_mm256_castps256_ps128(v));  // the bottom of the register
    float v1 = vecbuf[1];
    float v2 = vecbuf[2];
    ...
   // or loop over vecbuf[i]
   // if you do need all 8 elements one at a time, this is a good way
}

或循环 vecbuf[i]。矢量存储可以转发到其元素之一的标量重新加载,因此这只会引入大约 6 个周期的延迟,并且可以同时进行多个重新加载。 (因此对于具有 2/clock 负载吞吐量的现代 CPU 的吞吐量非常有用。)

请注意,我避免重新加载低元素;寄存器中向量的低位元素已经 标量 float_mm_cvtss_f32( _mm256_castps256_ps128(v) ) 只是让编译器的类型系统满意的方法;它编译为零 asm 指令,因此它实际上是免费的(除非错过优化错误)。 (参见 Intel's intrinsics guide)。 XMM 寄存器是相应 YMM 寄存器的低 128 位,标量 float / double 是 XMM 寄存器的低 32 位或 64 位。 (上半部分的垃圾无所谓。)

投射第一个让 OoO exec 在等待其余部分到达时有事可做。您可能会考虑在低位 128 上使用 vunpckhpsvmovhlps 进行洗牌以获得第二个元素,这样您就可以快速准备好 2 个元素,如果这有助于填补延迟气泡的话。

在 GNU C/C++ 中,您可以使用 v[1] 甚至像 v[i] 这样的变量索引来索引向量类型,例如数组。编译器将在 shuffle 或 store/reload.

之间进行选择

但这不能移植到 MSVC,MSVC 根据与一些命名成员的联合定义 __m256

存储到数组并重新加载是可移植的,编译器有时甚至可以将其优化为 shuffle。(如果您不想那样,请检查生成的 asm。)

例如clang 将一个刚刚 returns vecbuf[1] 的函数优化成一个简单的 vshufps。 https://godbolt.org/z/tHJH_V


如果你真的想把一个向量的所有元素加起来成为一个标量总数,shuffle 和 SIMD addFastest way to do horizontal float vector sum on x86

(对于单个向量的元素的乘法、最小值、最大值或其他关联归约相同。当然,如果您有多个向量,请对一个向量执行垂直操作,例如 _mm256_add_ps(v1,v2)


使用 Agner Fog's Vector Class Library,他的包装器 类 重载 operator[] 完全按照您期望的方式工作,即使对于非常量参数也是如此。这通常会编译成 store/reload,但它使用 C++ 编写代码变得容易。启用优化后,您可能会获得不错的结果。 (除了 low 元素可能会得到 stored/reloaded,而不仅仅是就地使用。所以你可能想将 vec[0] 特殊化为 _mm_cvtss_f32(vec) 或其他东西。)

(VCL 以前是在 GPL 下授权的,但现在的版本是一个简单的 Apache 许可证。)

另请参阅我的 github repo,其中对 Agner 的 VCL 进行了大部分未经测试的更改,以便为某些函数生成更好的代码。


有一个 _MM_EXTRACT_FLOAT wrapper macro, but it's weird and only defined with SSE4.1. I think it's intended to go with SSE4.1 extractps (which can extract the binary representation of a float into an integer register, or store to memory). It gcc does compile it into an FP shuffle when the destination is a float, though. Be careful that other compilers don't compile it to an actual extractps instruction if you want the result as a float, because that's not what extractps does. (That is what insertps does,但更简单的 FP 洗牌将占用更少的指令字节。例如shufps 使用 AVX 很棒。)

这很奇怪,因为它需要 3 个参数:_MM_EXTRACT_FLOAT(dest, src_m128, idx),所以你甚至不能将它用作 float 本地的初始化程序。


循环向量

gcc 将为您展开类似的循环,但仅限于 -O1 或更高版本。在-O0,它会给你一个错误信息。

float bad_hsum(__m128 & fv) {
    float sum = 0;
    for (int i=0 ; i<4 ; i++) {
        float f;
        _MM_EXTRACT_FLOAT(f, fv, i);  // works only with -O1 or higher
        sum += f;
    }
    return sum;
}
    float valueAVX(__m256 a, int i){

        float ret = 0;
        switch (i){

            case 0:
//                 a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)      ( a3, a2, a1, a0 )
// cvtss_f32             a0 

                ret = _mm_cvtss_f32(_mm256_extractf128_ps(a, 0));
                break;
            case 1: {
//                     a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)     lo = ( a3, a2, a1, a0 )
// shuffle(lo, lo, 1)      ( - , a3, a2, a1 )
// cvtss_f32                 a1 
                __m128 lo = _mm256_extractf128_ps(a, 0);
                ret = _mm_cvtss_f32(_mm_shuffle_ps(lo, lo, 1));
            }
                break;
            case 2: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)   lo = ( a3, a2, a1, a0 )
// movehl(lo, lo)        ( - , - , a3, a2 )
// cvtss_f32               a2 
                __m128 lo = _mm256_extractf128_ps(a, 0);
                ret = _mm_cvtss_f32(_mm_movehl_ps(lo, lo));
            }
                break;
            case 3: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)   lo = ( a3, a2, a1, a0 )
// shuffle(lo, lo, 3)    ( - , - , - , a3 )
// cvtss_f32               a3 
                __m128 lo = _mm256_extractf128_ps(a, 0);                    
                ret = _mm_cvtss_f32(_mm_shuffle_ps(lo, lo, 3));
            }
                break;

            case 4:
//                 a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)      ( a7, a6, a5, a4 )
// cvtss_f32             a4 
                ret = _mm_cvtss_f32(_mm256_extractf128_ps(a, 1));
                break;
            case 5: {
//                     a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)     hi = ( a7, a6, a5, a4 )
// shuffle(hi, hi, 1)      ( - , a7, a6, a5 )
// cvtss_f32                 a5 
                __m128 hi = _mm256_extractf128_ps(a, 1);
                ret = _mm_cvtss_f32(_mm_shuffle_ps(hi, hi, 1));
            }
                break;
            case 6: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)   hi = ( a7, a6, a5, a4 )
// movehl(hi, hi)        ( - , - , a7, a6 )
// cvtss_f32               a6 
                __m128 hi = _mm256_extractf128_ps(a, 1);
                ret = _mm_cvtss_f32(_mm_movehl_ps(hi, hi));
            }
                break;
            case 7: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)   hi = ( a7, a6, a5, a4 )
// shuffle(hi, hi, 3)    ( - , - , - , a7 )
// cvtss_f32               a7 
                __m128 hi = _mm256_extractf128_ps(a, 1);
                ret = _mm_cvtss_f32(_mm_shuffle_ps(hi, hi, 3));
            }
                break;
        }

        return ret;
    }