提高矩阵乘法的性能

Improving the performance of Matrix Multiplication

这是我的加速矩阵乘法的代码,但它只比简单的快 5%。 我能做些什么来尽可能地提高它?

*正在访问这些表,例如:C[sub2ind(i,j,n)] for the C[i, j]位置。

void matrixMultFast(float * const C,            /* output matrix */
                float const * const A,      /* first matrix */
                float const * const B,      /* second matrix */
                int const n,                /* number of rows/cols */
                int const ib,               /* size of i block */
                int const jb,               /* size of j block */
                int const kb)               /* size of k block */
{

int i=0, j=0, jj=0, k=0, kk=0;
float sum;

for(i=0;i<n;i++)
    for(j=0;j<n;j++)
        C[sub2ind(i,j,n)]=0;

for(kk=0;kk<n;kk+=kb)
{
    for(jj=0;jj<n;jj+=jb)
    {
        for(i=0;i<n;i++)
        {
            for(j=jj;j<jj+jb;j++)
            {
                sum=C[sub2ind(i,j,n)];
                for(k=kk;k<kk+kb;k++)
                    sum += A[sub2ind(i,k,n)]*B[sub2ind(k,j,n)];
                C[sub2ind(i,j,n)]=sum;
            }
        }
    }
}
} // end function 'matrixMultFast4'

*C语言编写,需要支持C99

你可以做很多很多事情来提高矩阵乘法的效率。

为了研究如何改进基本算法,让我们先来看看我们当前的选择。当然,天真的实现有 3 个循环,时间复杂度约为 O(n^3)。还有另一种方法称为 Strassen 方法,它实现了可观的加速并具有 O(n^2.73) 的顺序(但在实践中是无用的,因为它没有提供可感知的优化方法)。

这是理论上的。现在考虑矩阵是如何存储在内存中的。行专业是标准的,但你也可以找到列专业。根据方案的不同,转置矩阵可能会由于缓存未命中次数减少而提高速度。理论上矩阵乘法只是一堆向量点积和加法。多个向量将对同一个向量进行操作,因此将该向量保存在缓存中以便更快地访问是有意义的。

现在,随着多核、并行性和缓存概念的引入,游戏规则发生了变化。如果我们仔细观察,我们会发现点积只不过是一堆乘法,然后是求和。这些乘法可以并行进行。因此,我们现在可以看看数字的并行加载。

现在让我们把事情变得更复杂一点。在谈论矩阵乘法时,单浮点数和双浮点数在大小上有所区别。通常前者是32位而后者是64位(当然,这取决于CPU)。每个 CPU 只有固定数量的寄存器,这意味着你的数字越大,你能容纳在 CPU 中的空间就越小。这个故事的寓意是,除非你真的需要双浮点数,否则坚持使用单浮点数。

现在我们已经了解了如何调整矩阵乘法的基础知识,不用担心。您无需执行上面讨论的任何操作,因为已经有子程序可以执行此操作。正如评论中提到的,有 GotoBLAS、OpenBLAS、Intel 的 MKL 和 Apple 的 Accelerate 框架。 MKL/Accelerate 是专有的,但 OpenBLAS 是一个非常有竞争力的替代品。

这是一个很好的小例子,它在我的 Macintosh 上在几毫秒内将 2 个 8k x 8k 矩阵相乘:

#include <sys/time.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <Accelerate/Accelerate.h>

int SIZE = 8192;

typedef float point_t;

point_t* transpose(point_t* A) {    
    point_t* At = (point_t*) calloc(SIZE * SIZE, sizeof(point_t));    
    vDSP_mtrans(A, 1, At, 1, SIZE, SIZE);

    return At;
}

point_t* dot(point_t* A, point_t* B) {
    point_t* C = (point_t*) calloc(SIZE * SIZE, sizeof(point_t));       
    int i;    
    int step = (SIZE * SIZE / 4);

    cblas_sgemm (CblasRowMajor, 
       CblasNoTrans, CblasNoTrans, SIZE/4, SIZE, SIZE,
       1.0, &A[0], SIZE, B, SIZE, 0.0, &C[0], SIZE);

    cblas_sgemm (CblasRowMajor, 
       CblasNoTrans, CblasNoTrans, SIZE/4, SIZE, SIZE,
       1.0, &A[step], SIZE, B, SIZE, 0.0, &C[step], SIZE);

    cblas_sgemm (CblasRowMajor, 
       CblasNoTrans, CblasNoTrans, SIZE/4, SIZE, SIZE,
       1.0, &A[step * 2], SIZE, B, SIZE, 0.0, &C[step * 2], SIZE);

    cblas_sgemm (CblasRowMajor, 
       CblasNoTrans, CblasNoTrans, SIZE/4, SIZE, SIZE,
       1.0, &A[step * 3], SIZE, B, SIZE, 0.0, &C[step * 3], SIZE);      

    return C;
}

void print(point_t* A) {
    int i, j;
    for(i = 0; i < SIZE; i++) {
        for(j = 0; j < SIZE; j++) {
            printf("%f  ", A[i * SIZE + j]);
        }
        printf("\n");
    }
}

int main() {
    for(; SIZE <= 8192; SIZE *= 2) {
        point_t* A = (point_t*) calloc(SIZE * SIZE, sizeof(point_t));
        point_t* B = (point_t*) calloc(SIZE * SIZE, sizeof(point_t));

        srand(getpid());

        int i, j;
        for(i = 0; i < SIZE * SIZE; i++) {
            A[i] = ((point_t)rand() / (double)RAND_MAX);
            B[i] = ((point_t)rand() / (double)RAND_MAX);
        }

        struct timeval t1, t2;
        double elapsed_time;

        gettimeofday(&t1, NULL);
        point_t* C = dot(A, B);
        gettimeofday(&t2, NULL);

        elapsed_time = (t2.tv_sec - t1.tv_sec) * 1000.0;      // sec to ms
        elapsed_time += (t2.tv_usec - t1.tv_usec) / 1000.0;   // us to ms

        printf("Time taken for %d size matrix multiplication: %lf\n", SIZE, elapsed_time/1000.0);

        free(A);
        free(B);
        free(C);

    }
    return 0;
}

此时我还应该提到 SSE(Streaming SIMD 扩展),这基本上是您不应该做的事情,除非您使用过汇编。基本上,您正在 向量化 您的 C 代码,以使用向量而不是整数。这意味着您可以对数据块而不是单个值进行操作。编译器放弃并按原样翻译您的代码,而不进行自己的优化。如果做得好,它可以像以前一样加速你的代码——你甚至可以触及 O(n^2) 的理论底线!但是滥用 SSE 很容易,不幸的是大多数人都这样做了,导致最终结果比以前更糟。

我希望这能激励您更深入地挖掘。矩阵乘法的世界是一个广阔而迷人的世界。下面,我附上链接以供进一步阅读。

  1. OpenBLAS
  2. More about SSE
  3. Intel Intrinsics