任何人都可以修改以下代码以避免使用 "if" 吗?

Can anyone modify the following code to avoid the usage of "if"?

我目前正在研究 cuda,我被下面的代码卡住了。该代码最初是用 matlab 编写的,我正在尝试使用 cuda 重写它:

Pv = 0; Nv =0;
[LOOP]
v1 = l(li);
v2 = l(li+1);
if((v1>0) && (v2>0))
    Pv = Pv + 1;
elseif((v1<0) && (v2<0))
    Nv = Nv +1;
elseif((v1>0) && (v2<0))
    r = v1/(v1-v2);
    Pv = Pv + r;
    Nv = Nv + 1 - r;
elseif((v1<0) && (v2>0))
    r = v2/(v2-v1);
    Pv = Pv + r;
    Nv = Nv + 1 - r;
end
[LOOP END]

但是,在 cuda 体系结构中,"if" 表达式有时是昂贵的,我相信有一些方法可以避免使用它,虽然我现在无法弄清楚。

代码的主要目的是计算正区间和负区间的比值,分别相加。在大多数情况下,v1 和 v2 具有相同的符号,但是一旦它们具有不同的符号,我必须使用一堆 "if" 甚至 "abs()" 来处理这种情况。

所以,谁能帮我用 C 重写代码,同时尽可能少地使用 "if "?

利用您使用的每个表达式都给出一个 0 或 1 二进制表达式这一事实,我可以将其归结为一个三元运算符:

// r is 0 if not case 3 or 4
r = (v1==v2)?0:((v1>v2)*v1-(v2>v1)*v2)/(v1-v2) * (v1*v2<0);
Pv += r + // cases 3 & 4
     (v1>0) && (v2>0); // case 1
Nv += r + // cases 3 & 4
     (v1<=0) || (v2<=0); // cases 2, 3 and 4

(编辑:进一步优化未经测试!)

但是,这听起来很像过早的优化。 if 语句的数量真的会导致问题吗?

添加到ablish's :

这个

r = (v1 == v2) ?0 :((v1 > v2) * v1 - (v2 > v1) * v2) / (v1 - v2) * (v1 * v2 < 0);

可以换成这个

double ra[2] = {((v1 > v2) * v1 - (v2 > v1) * v2) / (v1 - v2) * (v1 * v2 < 0), 0.};
r = ra[!!(v1 == v2)];

删除最后一个三元运算符。


更新:

!! 没有用,所以上面的第二行应该是:

r = ra[v1 == v2];

正如 abligh 中指出的那样,我们可能会在 ra 的初始化期间得到除以零的结果,所以我们只想做适当计算:

double r0(double v1, double v2)
{
  return 0;
}

double r1(double v1, double v2)
{
  return ((v1 > v2) * v1 - (v2 > v1) * v2) / (v1 - v2) * (v1 * v2 < 0);
}

double (*rf[2])(double v1, double v2) = {r0, r1}; 

...

r = rf[fabs(v1 - v2) < EPSILON](v1, v2); /* With EPSILON like 0.00001 or what ever accuracy is needed. */

既然你说类似符号的非零操作数的情况是最常见的情况,最好使用分支来处理罕见的情况 3 和 4,特别是因为它们需要昂贵的除法运算我们不想在公共(快速)路径中执行。考虑到 min()max() 以及三元运算符直接由 GPU 中的硬件支持,并且 GPU 对预测有广泛的支持,我建议稍微重写如下所示。我已经使用从集合 {-0.0f, -3.0f, -5.0f, 0.0f, 3.0f, 5.0f} 中抽取的两个元素的随机组合对此进行了测试,以确保涵盖 v1==0v2==0 的情况。

float Nv, Pv, v1, v2;
float r;
if ((v1 > 0) && (v2 > 0)) {
    Pv = Pv + 1;
} 
if ((v1 < 0) && (v2 < 0)) {
    Nv = Nv + 1;
}
if (((v1 > 0) && (v2 < 0)) || ((v1 < 0) && (v2 > 0))) { // rare case
    float s = min (v1, v2);
    float t = max (v1, v2);
    r = t / (t - s);
    Pv = Pv + r;
    Nv = Nv + 1 - r;
}

为sm_20编译,编译器生成的代码是无分支的,除了罕见的慢路径:

   /*0008*/         FSETP.LT.AND P2, PT, RZ, c[0x0][0x2c], PT;
   /*0010*/         FSETP.GT.AND P3, PT, RZ, c[0x0][0x28], PT;
   /*0018*/         FSETP.LT.AND P1, PT, RZ, c[0x0][0x28], PT;
   /*0020*/         MOV32I R4, 0x3f800000;
   /*0028*/         FSETP.GT.AND P4, PT, RZ, c[0x0][0x2c], PT;
   /*0030*/         PSETP.AND.AND P5, PT, P3, P2, PT;
   /*0038*/         PSETP.AND.AND P0, PT, P1, P2, PT;
   /*0040*/         FADD R0, R4, c[0x0][0x24];
   /*0048*/         PSETP.AND.AND P2, PT, P3, P4, PT;
   /*0050*/         FADD R4, R4, c[0x0][0x20];
   /*0058*/         PSETP.AND.OR P1, PT, P1, P4, P5;
   /*0060*/         MOV R2, c[0x0][0x30];
   /*0068*/         MOV R3, c[0x0][0x34];
   /*0070*/         SEL R0, R0, c[0x0][0x24], P0;
   /*0078*/         SEL R6, R4, c[0x0][0x20], P2;
   /*0080*/    @!P1 BRA 0xc8;
   /*0088*/         MOV R4, c[0x0][0x28];
   /*0090*/         FMNMX R5, R4, c[0x0][0x2c], PT;
   /*0098*/         FMNMX R4, R4, c[0x0][0x2c], !PT;
   /*00a0*/         FADD R5, R4, -R5;
   /*00a8*/         CAL 0xe0;                 // division
   /*00b0*/         FADD R5, R6, 1;
   /*00b8*/         FADD R0, R0, R4;
   /*00c0*/         FADD R6, R5, -R4;
   /*00c8*/         FADD R0, R0, R6;