java 中的并行矩阵乘法

Parallel matrix multiplication in java

我正在尝试用多线程实现矩阵乘法。一切似乎都正常工作,但是,它比通常的算法慢得多。这是我的代码

public class Main {
    private static int nRows = 500; //number of rows and columns in matrices
    private static int[][] matrix1 = new int[nRows][nRows]; //first matrix for multiplication
    private static int[][] matrix2 = new int[nRows][nRows]; //second matrix for multiplication
    private static int[][] result1 = new int[nRows][nRows]; //result from linear matrix multiplication
    private static int[][] result2 = new int[nRows][nRows]; //result from parallel matrix multiplication

    private static Thread[][] pool = new Thread[nRows][nRows]; //array of threads

    //method used for transposing a matrix to get its column easily
    public static int[][] transpose(int[][] matrix) {
        int[][] newMatrix = new int[matrix[0].length][matrix.length];
        for (int i = 0; i < matrix[0].length; i++) {
            for (int j = 0; j < matrix.length; j++) {
                newMatrix[i][j] = matrix[j][i];
            }
        }
        return newMatrix;
    }

    public static void main(String[] args) {
        //initializing input matrices (setting all elements = 1)
        for (int i = 0; i < nRows; i++) {
            for (int j = 0; j < nRows; j++) {
                matrix1[i][j] = 1;
                matrix2[i][j] = 1;
            }
        }

        long start;
        long end;

        System.out.println("Linear algorithm");
        start = System.currentTimeMillis();

        //linear multiplication algorithm
        for (int i = 0; i < nRows; i++) {
            for (int j = 0; j < nRows; j++) {
                int temp = 0;
                for (int k = 0; k < nRows; k++) {
                    temp += matrix1[i][k] * matrix2[k][j];
                }
                result1[i][j] = temp;
            }
        }

        //show result
//        for(int i=0;i<nRows;i++){
//            for(int j=0;j<nRows;j++){
//                System.out.print(result1[i][j] + " ");
//            }
//            System.out.println();
//        }

        end = System.currentTimeMillis();
        System.out.println("Time with linear algorithm: " + (end - start));

        //--------------------

        System.out.println("Parallel algorithm");
        start = System.currentTimeMillis();

        int[][] matrix3 = transpose(matrix2); //get a transpose copy of second matrix

        for (int i = 0; i < nRows; i++) {
            for (int j = 0; j < nRows; j++) {
                pool[i][j] = new myThread(matrix1[i], matrix3[j], i, j); //creating a thread for each element
                pool[i][j].start(); //starting a thread
            }
        }

        for (int i = 0; i < nRows; i++) {
            for (int j = 0; j < nRows; j++) {
                try {
                    pool[i][j].join(); //waiting for the thread to finish its job
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }

        //show the result
//        for(int i=0;i<nRows;i++){
//            for(int j=0;j<nRows;j++){
//                System.out.print(result2[i][j] + " ");
//            }
//            System.out.println();
//        }

        end = System.currentTimeMillis();
        System.out.println("Time with parallel algorithm: " + (end - start));
    }

    //class, where parallel multiplication is implemented
    private static class myThread extends Thread {
        private int[] row = new int[nRows]; //row for multiplication
        private int[] col = new int[nRows]; //column for multiplication
        private int i;  //row index of the element in resulting matrix
        private int j; //column index of the element in resulting matrix

        //constructor
        public myThread(int[] r, int[] c, int i, int j) {
            row = r;
            col = c;
            this.i = i;
            this.j = j;
        }

        public void run() {
            int temp = 0;
            for (int k = 0; k < nRows; k++) {
                temp += row[k] * col[k]; //getting the element by multiplying row and column of two matrices
            }
            result2[i][j] = temp; //writing the resulting element to the resulting matrix
        }
    }
}

在这里,我为结果矩阵中的每个元素创建了一个新线程。我不是将这些线程写入一个数组,启动它们,最后等待它们完成工作。我已经看到了一些实现,其中整个输入矩阵(它们两个)将作为参数提供给线程。然而,我的任务是想出一个算法,其中只给出一行和一列(对于这个特定元素是必需的)。

测量经过的时间后,我得到以下结果

Linear algorithm
Time with linear algorithm: 557
Parallel algorithm
Time with parallel algorithm: 38262

我做错了什么?提前致谢!

您编写的代码在 GPU 上运行良好,其中线程的概念非常不同,开销基本为零。在基于 CPU 的系统上,生成线程是一个异常缓慢的操作,只有当您可以将此开销分摊到 大量 的计算工作时才有意义。

这里有一些一般性建议,可以帮助您为 CPU 编写更好的并行算法:

  • 对于计算量大的任务,使用与物理执行单元(核心)一样多的线程。除非存在大量内存延迟,否则诸如超线程之类的 SMT 技术帮助不大。对于适合 L1 和 L2 CPU 缓存的小矩阵,延迟非常低,SMT 没有任何好处。当多个线程共享同一个内核时,OS 必须在两者之间进行上下文切换,这会增加开销并可能破坏缓存。
  • 保持并行化粒度尽可能粗,以便最大化每个线程的工作量。让每个线程对连续的行/列块进行操作,而不是让每个线程进行一行 x 列操作。您可以尝试仅并行化外部循环,即仅并行化第一个矩阵的行。
  • 保持线程数取决于硬件属性(内核数)并且独立于问题大小。为每一行和每一列生成一个单独的线程会使开销与问题大小成线性关系,从性能的角度来看这确实很糟糕。
  • 避免虚假分享。当不同内核上的两个或多个线程 运行 写入同一缓存行中的内存位置时,就会发生这种情况。当一个线程更新其核心的缓存时,更改会传播并使具有相同缓存行的其他核心的缓存无效,从而迫使它们重新获取数据。在您的情况下,result2 的 16 个连续值落在同一个缓存行中(x86 和 ARM 上的缓存行长 64 个字节,int 是 4 个字节)并且由 16 个不同的线程写入。使用临时求和变量以某种方式缓解了这个问题——当错误共享在内部(-most)循环中重复发生时,它会更加严重。
  • 当工作项的数量超过线程数并且每个线程将多次获得工作时,对重复任务使用线程池。在你的例子中,你给每个线程一个单独的工作项,所以这不是真正的池化。

总而言之,启动与物理核心一样多的线程,并让它们处理输入矩阵的大连续块。

并行处理只适用于大量处理器。如果您没有足够的处理器来拆分工作,也没有足够的处理器来处理负载,那么并行化不会让您大吃一惊。并行化可能会使处理变慢。

如果您有 P 个处理器和 8(P) 个并发请求,那么每个请求使用一个线程通常对吞吐量更有效。分解的盈亏平衡点刚好大于 8(P),具体取决于应用。这里的逻辑很简单。如果您有 P 个处理器可用,并且相应地拆分了您的工作,但前面还有数百个其他任务,那么拆分的意义何在?按顺序处理每个请求可能会更快。

矩阵乘法是一个真正的内存消耗。如果没有足够的内存,并行化可能会使处理变慢。

也就是说,拆分和合并的好方法太冗长,无法在此处复制。我维护着一个开源 divide-and-conquer product,它的内置函数之一是矩阵乘法。看看并尽你所能。