使用 goroutine 进行矩阵乘法会降低性能

Matrix multiplication with goroutine drops performance

我正在通过 Go 中的 goroutines 优化矩阵乘法。

我的基准测试显示,每行或每元素引入并发会大大降低性能:

goos: darwin
goarch: amd64
BenchmarkMatrixDotNaive/A.MultNaive-8                            2000000               869 ns/op               0 B/op          0 allocs/op
BenchmarkMatrixDotNaive/A.ParalMultNaivePerRow-8                  100000             14467 ns/op              80 B/op          9 allocs/op
BenchmarkMatrixDotNaive/A.ParalMultNaivePerElem-8                  20000             77299 ns/op             528 B/op         65 allocs/op

我知道缓存局部性的一些基本先验知识,每个元素并发性降低性能是有道理的。但是,为什么即使在原始版本中,每行仍然会降低性能?

其实我也写了一个block/tiling优化,它的vanilla版本(没有goroutine并发)甚至比naive版本还差(这里不介绍,先关注naive)。

我做错了什么?为什么?这里怎么优化?

乘法:

package naive

import (
    "errors"
    "sync"
)

// Errors
var (
    ErrNumElements = errors.New("Error number of elements")
    ErrMatrixSize  = errors.New("Error size of matrix")
)

// Matrix is a 2d array
type Matrix struct {
    N    int
    data [][]float64
}

// New a size by size matrix
func New(size int) func(...float64) (*Matrix, error) {
    wg := sync.WaitGroup{}
    d := make([][]float64, size)
    for i := range d {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            d[i] = make([]float64, size)
        }(i)
    }
    wg.Wait()
    m := &Matrix{N: size, data: d}
    return func(es ...float64) (*Matrix, error) {
        if len(es) != size*size {
            return nil, ErrNumElements
        }
        for i := range es {
            wg.Add(1)
            go func(i int) {
                defer wg.Done()
                m.data[i/size][i%size] = es[i]
            }(i)
        }
        wg.Wait()
        return m, nil
    }
}

// At access element (i, j)
func (A *Matrix) At(i, j int) float64 {
    return A.data[i][j]
}

// Set set element (i, j) with val
func (A *Matrix) Set(i, j int, val float64) {
    A.data[i][j] = val
}

// MultNaive matrix multiplication O(n^3)
func (A *Matrix) MultNaive(B, C *Matrix) (err error) {
    var (
        i, j, k int
        sum     float64
        N       = A.N
    )

    if N != B.N || N != C.N {
        return ErrMatrixSize
    }

    for i = 0; i < N; i++ {
        for j = 0; j < N; j++ {
            sum = 0.0
            for k = 0; k < N; k++ {
                sum += A.At(i, k) * B.At(k, j)
            }
            C.Set(i, j, sum)
        }
    }
    return
}

// ParalMultNaivePerRow matrix multiplication O(n^3) in concurrency per row
func (A *Matrix) ParalMultNaivePerRow(B, C *Matrix) (err error) {
    var N = A.N

    if N != B.N || N != C.N {
        return ErrMatrixSize
    }

    wg := sync.WaitGroup{}
    for i := 0; i < N; i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            for j := 0; j < N; j++ {
                sum := 0.0
                for k := 0; k < N; k++ {
                    sum += A.At(i, k) * B.At(k, j)
                }
                C.Set(i, j, sum)
            }
        }(i)
    }
    wg.Wait()
    return
}

// ParalMultNaivePerElem matrix multiplication O(n^3) in concurrency per element
func (A *Matrix) ParalMultNaivePerElem(B, C *Matrix) (err error) {
    var N = A.N

    if N != B.N || N != C.N {
        return ErrMatrixSize
    }

    wg := sync.WaitGroup{}
    for i := 0; i < N; i++ {
        for j := 0; j < N; j++ {
            wg.Add(1)
            go func(i, j int) {
                defer wg.Done()
                sum := 0.0
                for k := 0; k < N; k++ {
                    sum += A.At(i, k) * B.At(k, j)
                }
                C.Set(i, j, sum)
            }(i, j)
        }
    }
    wg.Wait()
    return
}

基准:

package naive

import (
    "os"
    "runtime/trace"
    "testing"
)

type Dot func(B, C *Matrix) error

var (
    A = &Matrix{
        N: 8,
        data: [][]float64{
            []float64{1, 2, 3, 4, 5, 6, 7, 8},
            []float64{9, 1, 2, 3, 4, 5, 6, 7},
            []float64{8, 9, 1, 2, 3, 4, 5, 6},
            []float64{7, 8, 9, 1, 2, 3, 4, 5},
            []float64{6, 7, 8, 9, 1, 2, 3, 4},
            []float64{5, 6, 7, 8, 9, 1, 2, 3},
            []float64{4, 5, 6, 7, 8, 9, 1, 2},
            []float64{3, 4, 5, 6, 7, 8, 9, 0},
        },
    }
    B = &Matrix{
        N: 8,
        data: [][]float64{
            []float64{9, 8, 7, 6, 5, 4, 3, 2},
            []float64{1, 9, 8, 7, 6, 5, 4, 3},
            []float64{2, 1, 9, 8, 7, 6, 5, 4},
            []float64{3, 2, 1, 9, 8, 7, 6, 5},
            []float64{4, 3, 2, 1, 9, 8, 7, 6},
            []float64{5, 4, 3, 2, 1, 9, 8, 7},
            []float64{6, 5, 4, 3, 2, 1, 9, 8},
            []float64{7, 6, 5, 4, 3, 2, 1, 0},
        },
    }
    C = &Matrix{
        N: 8,
        data: [][]float64{
            []float64{0, 0, 0, 0, 0, 0, 0, 0},
            []float64{0, 0, 0, 0, 0, 0, 0, 0},
            []float64{0, 0, 0, 0, 0, 0, 0, 0},
            []float64{0, 0, 0, 0, 0, 0, 0, 0},
            []float64{0, 0, 0, 0, 0, 0, 0, 0},
            []float64{0, 0, 0, 0, 0, 0, 0, 0},
            []float64{0, 0, 0, 0, 0, 0, 0, 0},
            []float64{0, 0, 0, 0, 0, 0, 0, 0},
        },
    }
)

func BenchmarkMatrixDotNaive(b *testing.B) {
    f, _ := os.Create("bench.trace")
    defer f.Close()
    trace.Start(f)
    defer trace.Stop()

    tests := []struct {
        name string
        f    Dot
    }{
        {
            name: "A.MultNaive",
            f:    A.MultNaive,
        },
        {
            name: "A.ParalMultNaivePerRow",
            f:    A.ParalMultNaivePerRow,
        },
        {
            name: "A.ParalMultNaivePerElem",
            f:    A.ParalMultNaivePerElem,
        },
    }
    for _, tt := range tests {
        b.Run(tt.name, func(b *testing.B) {
            for i := 0; i < b.N; i++ {
                tt.f(B, C)
            }
        })
    }
}

执行 8x8 矩阵乘法是相对较小的工作。

Goroutines(虽然可能是轻量级的)确实有开销。如果他们所做的工作是 "small",启动、同步和丢弃它们的开销可能会超过利用多核/线程的性能增益,并且总体而言,您可能无法通过并发执行此类小任务来获得性能(见鬼,你甚至可能比不使用 goroutines 做得更糟)。测量。

如果我们将矩阵大小增加到 80x80,运行 我们已经在 ParalMultNaivePerRow:

的情况下看到一些性能提升
BenchmarkMatrixDotNaive/A.MultNaive-4               2000     1054775 ns/op
BenchmarkMatrixDotNaive/A.ParalMultNaivePerRow-4    2000      709367 ns/op
BenchmarkMatrixDotNaive/A.ParalMultNaivePerElem-4    100    10224927 ns/op

(正如您在结果中看到的,我有 4 个 CPU 核,运行 它在您的 8 核机器上可能显示出更多的性能提升。)

当行很小时,您正在使用 goroutines 做最少的工作,您可以通过不 "throwing" 离开 goroutines 一旦他们完成他们的 "tiny" 工作来提高性能,但是你可以 "reuse"他们。参见相关问题:

另请参阅相关/可能的重复项: