使用 FMA(融合乘法)指令进行复数乘法

Using FMA (fused multiply) instructions for complex multiplication

我想利用可用的融合乘法 add/subtract CPU 指令来协助对适当大小的数组进行复数乘法。本质上,基础数学看起来像这样:

void ComplexMultiplyAddToArray(float* pDstR, float* pDstI, const float* pSrc1R, const float* pSrc1I, const float* pSrc2R, const float* pSrc2I, int len)
{
    for (int i = 0; i < len; ++i)
    {
        const float fSrc1R = pSrc1R[i];
        const float fSrc1I = pSrc1I[i];
        const float fSrc2R = pSrc2R[i];
        const float fSrc2I = pSrc2I[i];

        //  Perform complex multiplication on the input and accumulate with the output
        pDstR[i] += fSrc1R*fSrc2R - fSrc1I*fSrc2I;
        pDstI[i] += fSrc1R*fSrc2I + fSrc2R*fSrc1I;
    }
}

正如您可能看到的那样,数据是结构化的,其中我们有单独的实数和虚数数组。现在,假设我有以下函数可用作分别执行 ab+c 和 ab-c 的单个指令的内在函数:

float fmadd(float a, float b, float c);
float fmsub(float a, float b, float c);

天真地,我可以看到我可以用一个 fmadd 和一个 fmsub 替换 2 个乘法、一个加法和一个减法,就像这样:

//  Perform complex multiplication on the input and accumulate with the output
pDstR[i] += fmsub(fSrc1R, fSrc2R, fSrc1I*fSrc2I);
pDstI[i] += fmadd(fSrc1R, fSrc2I, fSrc2R*fSrc1I);

这会带来非常适度的性能改进,我假设还会提高准确性,但我认为我确实缺少一些可以对数学进行代数修改的东西,这样我就可以再替换几个 mult/add或 mult/sub 组合。在每一行中,都有一个额外的加法和一个额外的乘法,我觉得我可以将它们转换为单个 fma,但令人沮丧的是,我不知道如何在不改变操作顺序和得到错误结果的情况下做到这一点。有想法的数学专家吗?

就问题而言,目标平台可能并不那么重要,因为我知道这些指令存在于各种平台上。

这是一个好的开始。可以再减一项:

//  Perform complex multiplication on the input and accumulate with the output
pDstR[i] += fmsub(fSrc1R, fSrc2R, fSrc1I*fSrc2I);
pDstI[i] += fmadd(fSrc1R, fSrc2I, fSrc2R*fSrc1I);

这里可以利用另一个fmadd计算虚部:

pDstI[i] = fmadd(fSrc1R, fSrc2I, fmadd(fSrc2R, fSrc1I, pDstI[i]));

同样你可以对实部做同样的事情,但是你需要否定这个论点。如果这会使事情变得更快或更慢,这在很大程度上取决于您正在处理的架构的微时序:

pDstR[i] = fmsub(fSrc1R, fSrc2R, fmadd(fSrc1I, fSrc2I, -pDstR[i]));

顺便说一句,如果您使用 restrict 关键字将目标数组声明为非别名,您 可能 获得进一步的性能改进。现在编译器必须假定 pDstR 和 pDstI 可能重叠或指向同一块内存。这将阻止编译器在写入 pDstR[i] 之前加载 pDstI[i]。

之后,如果编译器尚未这样做,一些仔细的循环展开也可能有所帮助。检查编译器的汇编输出!

我发现以下内容(在一些帮助下)似乎是正确答案:

pDstR[i] = fmsub(fSrc1R, fSrc2R, fmsub(fSrc1I, fSrc2I, pDstR[i]));
pDstI[i] = fmadd(fSrc1R, fSrc2I, fmadd(fSrc2R, fSrc1I, pDstI[i]));

但奇怪的是,在 AVX 上并没有像使用半 FMA 保留数学的真实结果部分,但让虚构结果使用完整 FMA 那样提高性能:

pDstR[i] += fmsub(fSrc1R, fSrc2R, fSrc1I*fSrc2I);
pDstI[i] = fmadd(fSrc1R, fSrc2I, fmadd(fSrc2R, fSrc1I, pDstI[i]));

感谢大家的帮助。