优化元素 2^x-1 的乘法
Optimizing multiplication of elements 2^x-1
是否有已知的乘以已知为 2^x-1 (1, 3, 7...) 的几个(3 到 5)字节 (int8) 的优化
这是在将字节数组与 (2^x-1)/2^x 相乘很多次的情况下。除法是微不足道的(为右移添加指数)但分子有点麻烦。
另外,指数x只有1..31,所有的和总是小于32
// In reality there are 16 of these (i.e. a[16], b[16], c[16])
// ( a + b + c ) < 32
char a = 2;
char b = 16;
char c = 8;
// Ratio/scale, there are 16 of these (i.e. r[16])
// It might work storing in log2 and using int8 or int16
// with fixed point approximation
<x?> r = ( a - 1 ) * ( b - 1 ) * ( c - 1 ) / ( a * b * c );
// Big original value, just one
int v = 1234567890;
// This might be done by scaling down to log2, too
// it is used for a comparison only
// doesn't need full 32b precission
// This is also 16 values, of course (i.e. rv[16])
int rv = v * r;
是不是就这么简单:
a * (2^x - 1) = (a << x) - a
您是否考虑过使用简单的预计算查找 table?如果我没看错你的问题,x0
、x1
、x2
总是在1到31之间,可以五位存储,所以只有2^15 = 32768
种组合。这意味着 r
可以通过几个位移位和按位 OR 来计算,以在相当小的 table.
中计算索引和单次查找
这个 table 查找当然不能向量化。
我所看到的是(与你上次的计算有点相反):
(2^a-1)(2^b-1)(2^c-1)=2^(a+b+c)-2^(a+b)-2^(b+c)-2^(a+c)
+ 2^a + 2^b + 2^c - 1
请注意,展开式中的所有项都是 2 的幂,根据您的约束,所有指数 < 32。当然,所有 32 个可能的术语都可以是 "precomputed"。然后只需总结 2^j 个这样的项(3 <= j <= 5 根据您的约束)。根据我的计算,对于上面的 j=3 情况,abc 有 4 个加法,"lookups" 有 7 个加法,术语有 7 个加法。我不知道这是否比只为您执行 3 "lookups"(2^x-1)和 2 次乘法(咬紧牙关)有所改进...
另请注意:乘以一个因数 2^y-1
可以通过 (y-1)
移位和 (y-1)
加法来完成。假设指数 a,b,c,d,e
,其中 a
最大,那就是 (b+c+d+e-4)
移位和 (b+c+d+e-4)
相加(从 2^a-1
开始)。
坦率地说,这个函数不太适合 AVX 指令集,它缺少整数运算。 SSE2 或 AVX2 提供的直接整数左移几乎肯定是最快的方法。但是,从您对 Aleksander Z. 的回答的评论来看,我了解到您正在寻求评估替代方法。
将这个问题强加到 AVX 单元上需要我们对 IEEE-754 representation 的数字发挥创意。通过未对齐的加载和按位掩码,我们可以将各个字节值混洗到 32 位浮点数的最上面的字节,其中定义数字的 2^n 次方的指数所在。
这几乎让我们得到了我们想要的幂函数,除了我们缺少指数字段的最低有效位并且需要使用平方根来向下调整它。同样我们也需要通过乘法来设置指数偏置。
无论如何,请查看下面的代码以了解详细信息,因为在这里逐字重复注释没有什么意义。请注意,未对齐读取(但忽略)最多三个字节 before 数组,因此请根据需要添加填充。另请注意,结果字是交错的,result1 存储字节 {0,4,8,12,..} 等等。
哦,显然结果将是使用浮点运算的近似值。
void compute(const unsigned char (*ptr)[32], size_t len) {
const __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x3F000000U));
const __m256 normalize = _mm256_castsi256_ps(_mm256_set1_epi32(0x7F000000U));
const __m256 offset = _mm256_set1_ps(1);
__m256 result1 = _mm256_set1_ps(1);
__m256 result2 = _mm256_set1_ps(1);
__m256 result3 = _mm256_set1_ps(1);
__m256 result4 = _mm256_set1_ps(1);
do {
// Mask out every forth byte into a separate variable using unaligned
// loads to simulate 8-to-32 bit integer unpacking
__m256 real1 = _mm256_loadu_ps((const float *) &ptr[0][-3]);
__m256 real2 = _mm256_loadu_ps((const float *) &ptr[0][-2]);
__m256 real3 = _mm256_loadu_ps((const float *) &ptr[0][-1]);
__m256 real4 = _mm256_loadu_ps((const float *) &ptr[0][-0]);
real1 = _mm256_and_ps(real1, mask);
real2 = _mm256_and_ps(real2, mask);
real3 = _mm256_and_ps(real3, mask);
real4 = _mm256_and_ps(real4, mask);
// The binary values are 2^2x * 2^-BIAS once the masked-once top bytes
// are interpreted as IEEE-754 floating-point exponent bytes.
// Unfortunately we are overshooting the exponent field by one bit,
// hence the doubled exponents. Anyway, let's at least multiply the
// bias away
real1 = _mm256_mul_ps(real1, normalize);
real2 = _mm256_mul_ps(real2, normalize);
real3 = _mm256_mul_ps(real3, normalize);
real4 = _mm256_mul_ps(real4, normalize);
// Use a fast aproximate reciprocal square root to halve the exponent,
// yielding ~1/2^x.
// You'd think this case of the reciprocal lookup table would be
// precise, yet it seems not to be. Perhaps twiddling the rounding
// mode or biasing the values may make it so.
real1 = _mm256_rsqrt_ps(real1);
real2 = _mm256_rsqrt_ps(real2);
real3 = _mm256_rsqrt_ps(real3);
real4 = _mm256_rsqrt_ps(real4);
// Compute (2^x-1)/2^x as 1-1/2^x
real1 = _mm256_sub_ps(offset, real1);
real2 = _mm256_sub_ps(offset, real2);
real3 = _mm256_sub_ps(offset, real3);
real4 = _mm256_sub_ps(offset, real4);
// Finally multiply the running products
result1 = _mm256_mul_ps(result1, real1);
result2 = _mm256_mul_ps(result2, real2);
result3 = _mm256_mul_ps(result3, real3);
result4 = _mm256_mul_ps(result4, real4);
} while(++ptr, --len);
/*
* Do something useful with result1..4 here
*/
}
像往常一样,问题要求优化解决方案,而不是关于如何解决原始问题的建议。这很烦人。
只使用了32个唯一乘数mx:
m0 = (20-1)/20 = 0/1 = 0
m1 = (21-1)/21 = 1/2 = 0.5
m2 = (22-1)/22 = 3/4 = 0.75
m3 = (23-1)/23 = 7/8 = 0.875
m4 = (24-1)/24 = 15/16 = 0.9375
m5 = (25-1)/25 = 31/32 = 0.96875
m6 = (26-1)/26 = 63/64 = 0.984375
m7 = (27-1)/27 = 127/128 = 0.9921875
m8 = (28-1)/28 = 255/256 = 0.99609375
m9 = (29-1)/29 = 511/512 = 0.998046875
m10 = (210-1)/210 = 1023/1024 = 0.9990234375
m11 = (211-1)/211 = 2047/2048 = 0.99951171875
m12 = (212-1)/212 = 4095/4096 = 0.999755859375
m13 = (213-1)/213 = 8191/8192 = 0.9998779296875
m14 = (214-1)/214 = 16383/16384 = 0.99993896484375
m15 = (215-1)/215 = 32767/32768 = 0.999969482421875
m16 = (216-1)/216 = 65535/65536 = 0.9999847412109375
m17 = (217-1)/217 = 131071/131072 = 0.99999237060546875
m18 = (218-1)/218 = 262143/262144 = 0.999996185302734375
m19 = (219-1)/219 = 524287/524288 = 0.9999980926513671875
m20 = (220-1)/220 = 1048575/1048576 = 0.99999904632568359375
m21 = (221-1)/221 = 2097151/2097152 = 0.999999523162841796875
m22 = (222-1)/222 = 4194303/4194304 = 0.9999997615814208984375
m23 = (223-1)/223 = 8388607/8388608 = 0.99999988079071044921875
m24 = (224-1)/224 = 16777215/16777216 = 0.999999940395355224609375
m25 = (225-1)/225 = 33554431/33554432 = 0.9999999701976776123046875
m26 = (226-1)/226 = 67108863/67108864 = 0.99999998509883880615234375
m27 = (227-1)/227 = 134217727/134217728 = 0.999999992549419403076171875
m28 = (228-1)/228 = 268435455/268435456 = 0.9999999962747097015380859375
m29 = (229-1)/229 = 536870911/536870912 = 0.99999999813735485076904296875
m30 = (230-1)/230 = 1073741823/1073741824 = 0.999999999068677425384521484375
m31 = (231-1)/231 = 2147483647/2147483648 = 0.9999999995343387126922607421875
以上小数均准确无误
三到五个乘数(上面table)相乘,最后加上一个"big number",得到最后的结果。
简单的查找 table 需要 32 个条目。包含两个乘法器乘积的查找 table 需要 322 = 1,024 个条目。包含三个乘法器乘积的查找 table 需要 323 = 32,768 个条目。四乘法器需要 1,048,576 个条目,并且通常太大而无法在当前处理器上实现缓存效率。
使用 Binary32 前 25 个条目(m0 到 m 24,包括)是准确的,但最后七个(m25 到 m31,含)无法表示,求值为1。因此,如果每个x被限制在范围[0 , 24],那么 binary32 系数就足够了。另外,"big number"乘以复合系数也只有七位左右的有效数字。
对于 Binary64,乘数将是精确的,并且 "big number" 将至少有 17 位有效数字。
SSE 向量 (__m128) 包含四个 Binary32 浮点数,AVX 向量 (__m256) 八个;分别是两个和四个 Binary64。如果要计算 16 个 "big numbers",这意味着两个、四个或八个矢量字,具体取决于体系结构和格式。
假设您改用 Binary64,并使用对齐的 table 两个 (mul2[32][32]
) 和三个 (mul3[32][32][32]
) 乘数乘积。在向量中计算 16 "big numbers" 的整体伪代码将归结为
将大数转换为 Binary64 向量,比如 num。
将大数最终系数打包成向量:
如果一个数有三个乘数a,b,c,加载
c1 = mul3[a][b][ c]
c2 = 1.0
如果一个数有四个乘数a,b,c,d, 加载
c1 = mul2[a][b]
c2 = mul2[c][d]
如果一个数有五个乘数a,b,c,d, e, 加载
c1 = mul3[a][b][ c]
c2 = mul2[d][e]
计算(向量相乘)
结果 = num · c1 · c2
如果需要,四舍五入到最接近的整数。
如果所有 num 都是非负数,只需将每个 result 加 0.5,以便以下截断正确。
截断 结果 并将其存储为 32 位或 64 位整数。
请注意,由于 table 查找,您最多只需要两次乘法——这些是向量乘法。 table 查找有点令人担忧,因为它必须分别为每个组件完成。
许多当前的处理器每个内核都有两个 ALU,因此展开和交错上述内容以便您一次对两个(或更多)不同的向量词进行操作通常会比简单地向量化计算产生显着的改进。这也意味着拥有比单个向量词所能容纳的更多数据是一件好事;它可以提高整体性能。
同样的方法显然也适用于 Binary32。
由于 table 查找可能是另一个瓶颈,交错整个操作,以便在进行向量乘法时在查找阶段开始处理另一个向量词对,应该让你从 CPU,尽管生成的代码理解起来有点复杂。
如果您选择这条路线,我强烈建议您将完整的伪代码描述作为注释添加到执行此操作的函数中。对以后的代码维护有很大帮助
是否有已知的乘以已知为 2^x-1 (1, 3, 7...) 的几个(3 到 5)字节 (int8) 的优化
这是在将字节数组与 (2^x-1)/2^x 相乘很多次的情况下。除法是微不足道的(为右移添加指数)但分子有点麻烦。
另外,指数x只有1..31,所有的和总是小于32
// In reality there are 16 of these (i.e. a[16], b[16], c[16])
// ( a + b + c ) < 32
char a = 2;
char b = 16;
char c = 8;
// Ratio/scale, there are 16 of these (i.e. r[16])
// It might work storing in log2 and using int8 or int16
// with fixed point approximation
<x?> r = ( a - 1 ) * ( b - 1 ) * ( c - 1 ) / ( a * b * c );
// Big original value, just one
int v = 1234567890;
// This might be done by scaling down to log2, too
// it is used for a comparison only
// doesn't need full 32b precission
// This is also 16 values, of course (i.e. rv[16])
int rv = v * r;
是不是就这么简单:
a * (2^x - 1) = (a << x) - a
您是否考虑过使用简单的预计算查找 table?如果我没看错你的问题,x0
、x1
、x2
总是在1到31之间,可以五位存储,所以只有2^15 = 32768
种组合。这意味着 r
可以通过几个位移位和按位 OR 来计算,以在相当小的 table.
这个 table 查找当然不能向量化。
我所看到的是(与你上次的计算有点相反):
(2^a-1)(2^b-1)(2^c-1)=2^(a+b+c)-2^(a+b)-2^(b+c)-2^(a+c)
+ 2^a + 2^b + 2^c - 1
请注意,展开式中的所有项都是 2 的幂,根据您的约束,所有指数 < 32。当然,所有 32 个可能的术语都可以是 "precomputed"。然后只需总结 2^j 个这样的项(3 <= j <= 5 根据您的约束)。根据我的计算,对于上面的 j=3 情况,abc 有 4 个加法,"lookups" 有 7 个加法,术语有 7 个加法。我不知道这是否比只为您执行 3 "lookups"(2^x-1)和 2 次乘法(咬紧牙关)有所改进...
另请注意:乘以一个因数 2^y-1
可以通过 (y-1)
移位和 (y-1)
加法来完成。假设指数 a,b,c,d,e
,其中 a
最大,那就是 (b+c+d+e-4)
移位和 (b+c+d+e-4)
相加(从 2^a-1
开始)。
坦率地说,这个函数不太适合 AVX 指令集,它缺少整数运算。 SSE2 或 AVX2 提供的直接整数左移几乎肯定是最快的方法。但是,从您对 Aleksander Z. 的回答的评论来看,我了解到您正在寻求评估替代方法。
将这个问题强加到 AVX 单元上需要我们对 IEEE-754 representation 的数字发挥创意。通过未对齐的加载和按位掩码,我们可以将各个字节值混洗到 32 位浮点数的最上面的字节,其中定义数字的 2^n 次方的指数所在。
这几乎让我们得到了我们想要的幂函数,除了我们缺少指数字段的最低有效位并且需要使用平方根来向下调整它。同样我们也需要通过乘法来设置指数偏置。
无论如何,请查看下面的代码以了解详细信息,因为在这里逐字重复注释没有什么意义。请注意,未对齐读取(但忽略)最多三个字节 before 数组,因此请根据需要添加填充。另请注意,结果字是交错的,result1 存储字节 {0,4,8,12,..} 等等。
哦,显然结果将是使用浮点运算的近似值。
void compute(const unsigned char (*ptr)[32], size_t len) {
const __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x3F000000U));
const __m256 normalize = _mm256_castsi256_ps(_mm256_set1_epi32(0x7F000000U));
const __m256 offset = _mm256_set1_ps(1);
__m256 result1 = _mm256_set1_ps(1);
__m256 result2 = _mm256_set1_ps(1);
__m256 result3 = _mm256_set1_ps(1);
__m256 result4 = _mm256_set1_ps(1);
do {
// Mask out every forth byte into a separate variable using unaligned
// loads to simulate 8-to-32 bit integer unpacking
__m256 real1 = _mm256_loadu_ps((const float *) &ptr[0][-3]);
__m256 real2 = _mm256_loadu_ps((const float *) &ptr[0][-2]);
__m256 real3 = _mm256_loadu_ps((const float *) &ptr[0][-1]);
__m256 real4 = _mm256_loadu_ps((const float *) &ptr[0][-0]);
real1 = _mm256_and_ps(real1, mask);
real2 = _mm256_and_ps(real2, mask);
real3 = _mm256_and_ps(real3, mask);
real4 = _mm256_and_ps(real4, mask);
// The binary values are 2^2x * 2^-BIAS once the masked-once top bytes
// are interpreted as IEEE-754 floating-point exponent bytes.
// Unfortunately we are overshooting the exponent field by one bit,
// hence the doubled exponents. Anyway, let's at least multiply the
// bias away
real1 = _mm256_mul_ps(real1, normalize);
real2 = _mm256_mul_ps(real2, normalize);
real3 = _mm256_mul_ps(real3, normalize);
real4 = _mm256_mul_ps(real4, normalize);
// Use a fast aproximate reciprocal square root to halve the exponent,
// yielding ~1/2^x.
// You'd think this case of the reciprocal lookup table would be
// precise, yet it seems not to be. Perhaps twiddling the rounding
// mode or biasing the values may make it so.
real1 = _mm256_rsqrt_ps(real1);
real2 = _mm256_rsqrt_ps(real2);
real3 = _mm256_rsqrt_ps(real3);
real4 = _mm256_rsqrt_ps(real4);
// Compute (2^x-1)/2^x as 1-1/2^x
real1 = _mm256_sub_ps(offset, real1);
real2 = _mm256_sub_ps(offset, real2);
real3 = _mm256_sub_ps(offset, real3);
real4 = _mm256_sub_ps(offset, real4);
// Finally multiply the running products
result1 = _mm256_mul_ps(result1, real1);
result2 = _mm256_mul_ps(result2, real2);
result3 = _mm256_mul_ps(result3, real3);
result4 = _mm256_mul_ps(result4, real4);
} while(++ptr, --len);
/*
* Do something useful with result1..4 here
*/
}
像往常一样,问题要求优化解决方案,而不是关于如何解决原始问题的建议。这很烦人。
只使用了32个唯一乘数mx:
m0 = (20-1)/20 = 0/1 = 0
m1 = (21-1)/21 = 1/2 = 0.5
m2 = (22-1)/22 = 3/4 = 0.75
m3 = (23-1)/23 = 7/8 = 0.875
m4 = (24-1)/24 = 15/16 = 0.9375
m5 = (25-1)/25 = 31/32 = 0.96875
m6 = (26-1)/26 = 63/64 = 0.984375
m7 = (27-1)/27 = 127/128 = 0.9921875
m8 = (28-1)/28 = 255/256 = 0.99609375
m9 = (29-1)/29 = 511/512 = 0.998046875
m10 = (210-1)/210 = 1023/1024 = 0.9990234375
m11 = (211-1)/211 = 2047/2048 = 0.99951171875
m12 = (212-1)/212 = 4095/4096 = 0.999755859375
m13 = (213-1)/213 = 8191/8192 = 0.9998779296875
m14 = (214-1)/214 = 16383/16384 = 0.99993896484375
m15 = (215-1)/215 = 32767/32768 = 0.999969482421875
m16 = (216-1)/216 = 65535/65536 = 0.9999847412109375
m17 = (217-1)/217 = 131071/131072 = 0.99999237060546875
m18 = (218-1)/218 = 262143/262144 = 0.999996185302734375
m19 = (219-1)/219 = 524287/524288 = 0.9999980926513671875
m20 = (220-1)/220 = 1048575/1048576 = 0.99999904632568359375
m21 = (221-1)/221 = 2097151/2097152 = 0.999999523162841796875
m22 = (222-1)/222 = 4194303/4194304 = 0.9999997615814208984375
m23 = (223-1)/223 = 8388607/8388608 = 0.99999988079071044921875
m24 = (224-1)/224 = 16777215/16777216 = 0.999999940395355224609375
m25 = (225-1)/225 = 33554431/33554432 = 0.9999999701976776123046875
m26 = (226-1)/226 = 67108863/67108864 = 0.99999998509883880615234375
m27 = (227-1)/227 = 134217727/134217728 = 0.999999992549419403076171875
m28 = (228-1)/228 = 268435455/268435456 = 0.9999999962747097015380859375
m29 = (229-1)/229 = 536870911/536870912 = 0.99999999813735485076904296875
m30 = (230-1)/230 = 1073741823/1073741824 = 0.999999999068677425384521484375
m31 = (231-1)/231 = 2147483647/2147483648 = 0.9999999995343387126922607421875
以上小数均准确无误
三到五个乘数(上面table)相乘,最后加上一个"big number",得到最后的结果。
简单的查找 table 需要 32 个条目。包含两个乘法器乘积的查找 table 需要 322 = 1,024 个条目。包含三个乘法器乘积的查找 table 需要 323 = 32,768 个条目。四乘法器需要 1,048,576 个条目,并且通常太大而无法在当前处理器上实现缓存效率。
使用 Binary32 前 25 个条目(m0 到 m 24,包括)是准确的,但最后七个(m25 到 m31,含)无法表示,求值为1。因此,如果每个x被限制在范围[0 , 24],那么 binary32 系数就足够了。另外,"big number"乘以复合系数也只有七位左右的有效数字。
对于 Binary64,乘数将是精确的,并且 "big number" 将至少有 17 位有效数字。
SSE 向量 (__m128) 包含四个 Binary32 浮点数,AVX 向量 (__m256) 八个;分别是两个和四个 Binary64。如果要计算 16 个 "big numbers",这意味着两个、四个或八个矢量字,具体取决于体系结构和格式。
假设您改用 Binary64,并使用对齐的 table 两个 (mul2[32][32]
) 和三个 (mul3[32][32][32]
) 乘数乘积。在向量中计算 16 "big numbers" 的整体伪代码将归结为
将大数转换为 Binary64 向量,比如 num。
将大数最终系数打包成向量:
如果一个数有三个乘数a,b,c,加载
c1 = mul3[a][b][ c]
c2 = 1.0如果一个数有四个乘数a,b,c,d, 加载
c1 = mul2[a][b]
c2 = mul2[c][d]如果一个数有五个乘数a,b,c,d, e, 加载
c1 = mul3[a][b][ c]
c2 = mul2[d][e]计算(向量相乘)
结果 = num · c1 · c2
如果需要,四舍五入到最接近的整数。
如果所有 num 都是非负数,只需将每个 result 加 0.5,以便以下截断正确。
截断 结果 并将其存储为 32 位或 64 位整数。
请注意,由于 table 查找,您最多只需要两次乘法——这些是向量乘法。 table 查找有点令人担忧,因为它必须分别为每个组件完成。
许多当前的处理器每个内核都有两个 ALU,因此展开和交错上述内容以便您一次对两个(或更多)不同的向量词进行操作通常会比简单地向量化计算产生显着的改进。这也意味着拥有比单个向量词所能容纳的更多数据是一件好事;它可以提高整体性能。
同样的方法显然也适用于 Binary32。
由于 table 查找可能是另一个瓶颈,交错整个操作,以便在进行向量乘法时在查找阶段开始处理另一个向量词对,应该让你从 CPU,尽管生成的代码理解起来有点复杂。
如果您选择这条路线,我强烈建议您将完整的伪代码描述作为注释添加到执行此操作的函数中。对以后的代码维护有很大帮助