如何优化长系列的 If/then 条件表达式 - SIMD

How to optimization long series of If/then conditional expressions - SIMD

我正在使用 SIMD 来提高 C 代码的性能,但我遇到了一个具有许多 if/then 条件的函数,如下所示:

if (Di <= -T3) return  -4;
if (Di <= -T2) return  -3;
if (Di <= -T1) return  -2;
if (Di < -NEAR)  return  -1;
if (Di <=  NEAR) return   0;
if (Di < T1)   return   1;
if (Di < T2)   return   2;
if (Di < T3)   return   3;

return  4;

使用 VC++ 编译器支持的英特尔内部函数会导致处理时间变慢。

那么有没有更好的方法来优化这一长串条件表达式?

您可以尝试完全摆脱条件并重新测量时间。 您的代码

if (Di <= -T3) return  -4;
if (Di <= -T2) return  -3;
if (Di <= -T1) return  -2;
if (Di < -NEAR)  return  -1;
if (Di <=  NEAR) return   0;
if (Di < T1)   return   1;
if (Di < T2)   return   2;
if (Di < T3)   return   3;

return  4;

可以转化为无条件形式:

return
    (Di <= -T3)*(-4) + (Di > -T3) * (
    (Di <= -T2)*(-3) + (Di > -T2) * (
    (Di <= -T1)*(-2) + (Di > -T1) * (
    (Di < -NEAR)*(-1) + (Di >= -NEAR) * (
    (Di <=  NEAR)*0 + (Di > NEAR) * (
    (Di < T1)*1 + (Di >= T1) * (
    (Di < T2)*2 + (Di >= T2) * (
    (Di < T3)*3 + (Di >= T3) * (
    4
    ))))))));

或许,您可以进一步优化此代码,了解您的变量的可能内容。

我假设几件事:

  1. 您处理的是 int32 数据(尽管它可以很容易地更改为 float32)。
  2. 您可以一次将 4 个值传递给您的函数(不只是一个)。这就是人们通常所说的矢量化
  3. 常量已排序,即 0 < NEAR < T1 < T2 < T3。

这是一个矢量化函数:

__m128i func4(__m128i D) {
  __m128i cmp_m3 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T3));
  __m128i cmp_m2 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T2));
  __m128i cmp_m1 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T1));
  __m128i cmp_p0 = _mm_cmpgt_epi32(D, _mm_set1_epi32(NEAR));
  __m128i reduce_true = _mm_add_epi32(_mm_add_epi32(cmp_m3, cmp_m2), _mm_add_epi32(cmp_m1, cmp_p0));
  __m128i cmp_m0 = _mm_cmplt_epi32(D, _mm_set1_epi32(-NEAR));
  __m128i cmp_p1 = _mm_cmplt_epi32(D, _mm_set1_epi32(T1));
  __m128i cmp_p2 = _mm_cmplt_epi32(D, _mm_set1_epi32(T2));
  __m128i cmp_p3 = _mm_cmplt_epi32(D, _mm_set1_epi32(T3));
  __m128i reduce_false = _mm_add_epi32(_mm_add_epi32(cmp_p3, cmp_p2), _mm_add_epi32(cmp_p1, cmp_m0));
  return _mm_sub_epi32(reduce_false, reduce_true);
}

如果输入数据是随机的,那么它的运行速度比使用 MSVC2013 x64 的 Ivy Bridge 上的原始版本快 11 倍:

Time = 4.436   (-39910000)
Time = 0.409   (-39910000)

带有测试的完整代码可用 here

这个想法很简单。 您可以在上面 link 之后的函数 funcX 中看到建议解决方案的非矢量化版本。它可能比语言更能解释一切。

我们将一个寄存器D作为输入,它包含4个打包值。 然后我们将它与你拥有的所有 8 个常量与 _mm_cmp* 内在的进行比较。此比较产生 8 个位掩码 cmp_pXcmp_mX。在位掩码中,与数字对应的所有位都为零或一。为每个比较设置 32 个零位,这是错误的。如果比较条件为真,则32位设置为1。

现在回想一下,全为 1 的 32 位整数在有符号表示中为 -1。当我们将四个比较结果加在一起时,我们得到一组否定的计数。最后,我们取两个计数的差值,就是我们想要的结果。

P.S。下面是为内循环生成的汇编代码:

movdqa  xmm3, XMMWORD PTR [rcx]
movdqa  xmm4, xmm10
movdqa  xmm0, xmm9
add rcx, 16
pcmpgtd xmm0, xmm3
pcmpgtd xmm4, xmm3
paddd   xmm4, xmm0
movdqa  xmm2, xmm3
movdqa  xmm1, xmm8
pcmpgtd xmm1, xmm3
pcmpgtd xmm2, xmm14
movdqa  xmm0, xmm7
pcmpgtd xmm0, xmm3
paddd   xmm1, xmm0
paddd   xmm4, xmm1
movdqa  xmm0, xmm3
movdqa  xmm1, xmm3
pcmpgtd xmm1, xmm12
pcmpgtd xmm0, xmm13
pcmpgtd xmm3, xmm11
paddd   xmm1, xmm3
paddd   xmm2, xmm0
paddd   xmm2, xmm1
psubd   xmm4, xmm2
paddd   xmm4, xmm5
movdqa  xmm5, xmm4
cmp rcx, r15
jl  SHORT $LL3@main