如何优化点积的 AVX 实现?
How can I optimize my AVX implementation of dot product?
我尝试使用 AVX 实现这两个数组的点积,但是我的代码很慢。
A 和 xb 是双精度数组,n 是偶数。你能帮助我吗?
/*
 * Question's original loop: accumulates the products A[...] * xb[...].x into
 * `sum`, consuming eight elements per full iteration (two 4-wide AVX
 * multiplies). A, xb, n and j are declared outside this excerpt.
 */
// NOTE(review): mask is not referenced anywhere in this excerpt.
const int mask = 0x31;
// NOTE(review): sum is an int, so each `sum +=` of a double value below is
// converted back to int on assignment -- confirm this is intended.
int sum =0;
for (int i = 0; i < n; i++)
{
int ind = i;
if (i + 8 > n) // tail ("padding"): fewer than 8 elements remain, use scalar code
{
// Scalar path: two elements per pass (the i++ below plus the for-loop's i++).
sum += A[ind] * xb[i].x;
i++;
// NOTE(review): the first product indexed A with ind == i, this one uses
// n * j + i; the two only agree when n * j == 0 -- verify against the caller.
ind = n * j + i;
sum += A[ind] * xb[i].x;
continue;
}
// Copy the .x members of four consecutive xb elements into a 32-byte-aligned
// temporary so they can be loaded as one 256-bit vector.
__declspec(align(32)) double ar[4] = { xb[i].x, xb[i + 1].x, xb[i + 2].x, xb[i + 3].x };
// Unaligned load of A[ind .. ind+3]; aligned load of the temporary.
__m256d x = _mm256_loadu_pd(&A[ind]);
__m256d y = _mm256_load_pd(ar);
// Move on to the second group of four elements.
i+=4; ind = n * j + i;
__declspec(align(32)) double arr[4] = { xb[i].x, xb[i + 1].x, xb[i + 2].x, xb[i + 3].x };
__m256d z = _mm256_loadu_pd(&A[ind]);
__m256d w = _mm256_load_pd(arr);
// Element-wise products of both groups.
__m256d xy = _mm256_mul_pd(x, y);
__m256d zw = _mm256_mul_pd(z, w);
// Horizontal reduction, performed on every iteration:
// hadd gives { xy0+xy1, zw0+zw1, xy2+xy3, zw2+zw3 }.
__m256d temp = _mm256_hadd_pd(xy, zw);
__m128d hi128 = _mm256_extractf128_pd(temp, 1);
__m128d low128 = _mm256_extractf128_pd(temp, 0);
//__m128d dotproduct = _mm_add_pd((__m128d)temp, hi128);
// Lane 0 now holds the sum of all four xy products, lane 1 the sum of all four zw products.
__m128d dotproduct = _mm_add_pd(low128, hi128);
sum += dotproduct.m128d_f64[0]+dotproduct.m128d_f64[1];
// With the i += 4 above and the for-loop's i++, i advances by 8 per full iteration.
i += 3;
}
您的循环中有两个明显的低效率问题:
(1) 这两块标量代码:
// Quoted from the loop: four xb[...].x values are copied one at a time into an aligned array ...
__declspec(align(32)) double ar[4] = { xb[i].x, xb[i + 1].x, xb[i + 2].x, xb[i + 3].x };
...
// ... and that array is then loaded as a single 256-bit vector.
__m256d y = _mm256_load_pd(ar);
和
// Quoted from the loop: the same scalar copy into an aligned array, for the second group of four ...
__declspec(align(32)) double arr[4] = { xb[i].x, xb[i + 1].x, xb[i + 2].x, xb[i + 3].x };
...
// ... followed by an aligned 256-bit load of that array.
__m256d w = _mm256_load_pd(arr);
应该使用 SIMD 加载和混洗(shuffle)指令来实现(或者至少使用 _mm256_set_pd,
让编译器有机会为这种聚集式加载(gathered load)生成代码)。
(2)每次循环迭代末尾的水平求和:
// Quoted from the question: the horizontal reduction below runs on every iteration.
for (int i = 0; i < n; i++)
{
...
__m256d xy = _mm256_mul_pd(x, y);
__m256d zw = _mm256_mul_pd(z, w);
// hadd gives { xy0+xy1, zw0+zw1, xy2+xy3, zw2+zw3 }.
__m256d temp = _mm256_hadd_pd(xy, zw);
__m128d hi128 = _mm256_extractf128_pd(temp, 1);
__m128d low128 = _mm256_extractf128_pd(temp, 0);
//__m128d dotproduct = _mm_add_pd((__m128d)temp, hi128);
// Adding the two 128-bit halves leaves one partial sum per lane; both are folded into sum.
__m128d dotproduct = _mm_add_pd(low128, hi128);
sum += dotproduct.m128d_f64[0]+dotproduct.m128d_f64[1];
i += 3;
}
应该移出循环:
// Suggested rewrite: keep two vector accumulators, zeroed before the loop.
__m256d xy = _mm256_setzero_pd();
__m256d zw = _mm256_setzero_pd();
...
for (int i = 0; i < n; i++)
{
...
// Inside the loop only element-wise multiply and add; no horizontal operations.
xy = _mm256_add_pd(xy, _mm256_mul_pd(x, y));
zw = _mm256_add_pd(zw, _mm256_mul_pd(z, w));
i += 3;
}
// Single horizontal reduction after the loop:
// hadd gives { xy0+xy1, zw0+zw1, xy2+xy3, zw2+zw3 }.
__m256d temp = _mm256_hadd_pd(xy, zw);
__m128d hi128 = _mm256_extractf128_pd(temp, 1);
__m128d low128 = _mm256_extractf128_pd(temp, 0);
//__m128d dotproduct = _mm_add_pd((__m128d)temp, hi128);
__m128d dotproduct = _mm_add_pd(low128, hi128);
// NOTE(review): the question declares sum as int, so this double total is
// converted to int when added -- confirm the intended type of sum.
sum += dotproduct.m128d_f64[0]+dotproduct.m128d_f64[1];
我尝试使用 AVX 实现这两个数组的点积,但是我的代码很慢。
A 和 xb 是双精度数组,n 是偶数。你能帮助我吗?
/*
 * Question's original loop: accumulates the products A[...] * xb[...].x into
 * `sum`, consuming eight elements per full iteration (two 4-wide AVX
 * multiplies). A, xb, n and j are declared outside this excerpt.
 */
// NOTE(review): mask is not referenced anywhere in this excerpt.
const int mask = 0x31;
// NOTE(review): sum is an int, so each `sum +=` of a double value below is
// converted back to int on assignment -- confirm this is intended.
int sum =0;
for (int i = 0; i < n; i++)
{
int ind = i;
if (i + 8 > n) // tail ("padding"): fewer than 8 elements remain, use scalar code
{
// Scalar path: two elements per pass (the i++ below plus the for-loop's i++).
sum += A[ind] * xb[i].x;
i++;
// NOTE(review): the first product indexed A with ind == i, this one uses
// n * j + i; the two only agree when n * j == 0 -- verify against the caller.
ind = n * j + i;
sum += A[ind] * xb[i].x;
continue;
}
// Copy the .x members of four consecutive xb elements into a 32-byte-aligned
// temporary so they can be loaded as one 256-bit vector.
__declspec(align(32)) double ar[4] = { xb[i].x, xb[i + 1].x, xb[i + 2].x, xb[i + 3].x };
// Unaligned load of A[ind .. ind+3]; aligned load of the temporary.
__m256d x = _mm256_loadu_pd(&A[ind]);
__m256d y = _mm256_load_pd(ar);
// Move on to the second group of four elements.
i+=4; ind = n * j + i;
__declspec(align(32)) double arr[4] = { xb[i].x, xb[i + 1].x, xb[i + 2].x, xb[i + 3].x };
__m256d z = _mm256_loadu_pd(&A[ind]);
__m256d w = _mm256_load_pd(arr);
// Element-wise products of both groups.
__m256d xy = _mm256_mul_pd(x, y);
__m256d zw = _mm256_mul_pd(z, w);
// Horizontal reduction, performed on every iteration:
// hadd gives { xy0+xy1, zw0+zw1, xy2+xy3, zw2+zw3 }.
__m256d temp = _mm256_hadd_pd(xy, zw);
__m128d hi128 = _mm256_extractf128_pd(temp, 1);
__m128d low128 = _mm256_extractf128_pd(temp, 0);
//__m128d dotproduct = _mm_add_pd((__m128d)temp, hi128);
// Lane 0 now holds the sum of all four xy products, lane 1 the sum of all four zw products.
__m128d dotproduct = _mm_add_pd(low128, hi128);
sum += dotproduct.m128d_f64[0]+dotproduct.m128d_f64[1];
// With the i += 4 above and the for-loop's i++, i advances by 8 per full iteration.
i += 3;
}
您的循环中有两个明显的低效率问题:
(1) 这两块标量代码:
// Quoted from the loop: four xb[...].x values are copied one at a time into an aligned array ...
__declspec(align(32)) double ar[4] = { xb[i].x, xb[i + 1].x, xb[i + 2].x, xb[i + 3].x };
...
// ... and that array is then loaded as a single 256-bit vector.
__m256d y = _mm256_load_pd(ar);
和
// Quoted from the loop: the same scalar copy into an aligned array, for the second group of four ...
__declspec(align(32)) double arr[4] = { xb[i].x, xb[i + 1].x, xb[i + 2].x, xb[i + 3].x };
...
// ... followed by an aligned 256-bit load of that array.
__m256d w = _mm256_load_pd(arr);
应该使用 SIMD 加载和混洗(shuffle)指令来实现(或者至少使用 _mm256_set_pd,
让编译器有机会为这种聚集式加载(gathered load)生成代码)。
(2)每次循环迭代末尾的水平求和:
// Quoted from the question: the horizontal reduction below runs on every iteration.
for (int i = 0; i < n; i++)
{
...
__m256d xy = _mm256_mul_pd(x, y);
__m256d zw = _mm256_mul_pd(z, w);
// hadd gives { xy0+xy1, zw0+zw1, xy2+xy3, zw2+zw3 }.
__m256d temp = _mm256_hadd_pd(xy, zw);
__m128d hi128 = _mm256_extractf128_pd(temp, 1);
__m128d low128 = _mm256_extractf128_pd(temp, 0);
//__m128d dotproduct = _mm_add_pd((__m128d)temp, hi128);
// Adding the two 128-bit halves leaves one partial sum per lane; both are folded into sum.
__m128d dotproduct = _mm_add_pd(low128, hi128);
sum += dotproduct.m128d_f64[0]+dotproduct.m128d_f64[1];
i += 3;
}
应该移出循环:
// Suggested rewrite: keep two vector accumulators, zeroed before the loop.
__m256d xy = _mm256_setzero_pd();
__m256d zw = _mm256_setzero_pd();
...
for (int i = 0; i < n; i++)
{
...
// Inside the loop only element-wise multiply and add; no horizontal operations.
xy = _mm256_add_pd(xy, _mm256_mul_pd(x, y));
zw = _mm256_add_pd(zw, _mm256_mul_pd(z, w));
i += 3;
}
// Single horizontal reduction after the loop:
// hadd gives { xy0+xy1, zw0+zw1, xy2+xy3, zw2+zw3 }.
__m256d temp = _mm256_hadd_pd(xy, zw);
__m128d hi128 = _mm256_extractf128_pd(temp, 1);
__m128d low128 = _mm256_extractf128_pd(temp, 0);
//__m128d dotproduct = _mm_add_pd((__m128d)temp, hi128);
__m128d dotproduct = _mm_add_pd(low128, hi128);
// NOTE(review): the question declares sum as int, so this double total is
// converted to int when added -- confirm the intended type of sum.
sum += dotproduct.m128d_f64[0]+dotproduct.m128d_f64[1];