矩阵乘法的最佳 numba 实现在很大程度上取决于矩阵大小

optimal numba implementations for matrix multiplication depends significantly on matrix size

这个问题与我之前发布的一个问题有关:
Python, numpy, einsum multiply a stack of matrices

我试图理解为什么在 乘以一堆矩阵 时以特定方式使用 Numba 时,我会得到加速。和以前一样,我放入一个 (500,201,2,2) 数组,沿第一个轴在末尾乘以 (2x2) 矩阵(所以 500 次乘法),得到一个 (201,2,2) 数组作为结果.

这里是 Python 代码:

from numba import jit  # numba 0.24, numpy 1.9.3, python 2.7.11

Arr = rand(500,201,2,2)

def loopMult(Arr):
    ArrMult = Arr[0]
    for i in range(1,len(Arr)):
        ArrMult = np.einsum('fij,fjk->fik', ArrMult, Arr[i])
    return ArrMult

@jit(nopython=True)
def loopMultJit(Arr):
    ArrMult = np.empty(shape=Arr.shape[1:], dtype=Arr.dtype)
    for i in range(0, Arr.shape[1]):
        ArrMult[i] = Arr[0, i]
        for j in range(1, Arr.shape[0]):
            ArrMult[i] = np.dot(ArrMult[i], Arr[j, i])
    return ArrMult

@jit(nopython=True)
def loopMultJit_2X2(Arr):
    ArrMult = np.empty(shape=Arr.shape[1:], dtype=Arr.dtype)
    for i in range(0, Arr.shape[1]):
        ArrMult[i] = Arr[0, i]
        for j in range(1, Arr.shape[0]):
            x1 = ArrMult[i,0,0] * Arr[j,i,0,0] + ArrMult[i,0,1] * Arr[j,i,1,0]
            y1 = ArrMult[i,0,0] * Arr[j,i,0,1] + ArrMult[i,0,1] * Arr[j,i,1,1]
            x2 = ArrMult[i,1,0] * Arr[j,i,0,0] + ArrMult[i,1,1] * Arr[j,i,1,0]
            y2 = ArrMult[i,1,0] * Arr[j,i,0,1] + ArrMult[i,1,1] * Arr[j,i,1,1]
            ArrMult[i,0,0] = x1
            ArrMult[i,0,1] = y1
            ArrMult[i,1,0] = x2
            ArrMult[i,1,1] = y2
    return ArrMult

A1 = loopMult(Arr)
A2 = loopMultJit(Arr)
A3 = loopMultJit_2X2(Arr)

print np.allclose(A1, A2)
print np.allclose(A1, A3)

%timeit loopMult(Arr)
%timeit loopMultJit(Arr)
%timeit loopMultJit_2X2(Arr)

这是输出:

True
True
10 loops, best of 3: 40.5 ms per loop
10 loops, best of 3: 36 ms per loop
1000 loops, best of 3: 808 µs per loop

在前面的问题中,接受的答案表明使用 f2py 可以在没有进行详细优化的情况下实现 8 倍的加速。在这里,使用 Numba,我在 einsum 循环中使用 numba 获得了大约 10% 的加速,但是如果我没有在循环中使用 np.dot,而是简单地手动执行 2x2 矩阵乘法,我获得了 45 倍的加速。为什么是这样?我应该提到我已经实现了这两个带有适当类型签名的 jit 函数作为 guvectorize 版本,它基本上提供了相同的加速因子,所以我把它们排除在外。迭代 201,500,2,2 矩阵的加速也很小。

2 条评论回复说加速只是由于 python 开销,我认为这是正确的。开销主要是函数调用,但也有 for 循环,并且 np.dot 有一些额外的开销。我设置了一个朴素的点积函数:

@jit(nopython=True)
def dot(mat1, mat2):
    s = 0
    mat = np.empty(shape=(mat1.shape[1], mat2.shape[0]), dtype=mat1.dtype)
    for r1 in range(mat1.shape[0]):
        for c2 in range(mat2.shape[1]):
            s = 0
            for j in range(mat2.shape[0]):
                s += mat1[r1,j] * mat2[j,c2]
            mat[r1,c2] = s
    return mat

然后我设置函数来乘以数组,一个调用点函数,一个在循环中内置点函数,这样它就可以在没有额外函数调用的情况下执行:

@jit(nopython=True)
def loopMultJit_dot(Arr):
    ArrMult = np.empty(shape=Arr.shape[1:], dtype=Arr.dtype)
    for i in range(0, Arr.shape[1]):
        ArrMult[i] = Arr[0, i]
        for j in range(1, Arr.shape[0]):
            ArrMult[i] = dot(ArrMult[i], Arr[j, i])
    return ArrMult

@jit(nopython=True)
def loopMultJit_dotInternal(Arr):
    ArrMult = np.empty(shape=Arr.shape[1:], dtype=Arr.dtype)
    for i in range(0, Arr.shape[1]):
        ArrMult[i] = Arr[0, i]
        for j in range(1, Arr.shape[0]):
            s = 0.0
            for r1 in range(ArrMult.shape[1]):
                for c2 in range(Arr.shape[3]):
                    s = 0.0
                    for r2 in range(Arr.shape[2]):
                        s += ArrMult[i,r1,r2] * Arr[j,i,r2,c2]
                    ArrMult[i,r1,c2] = s
    return ArrMult

然后我可以 运行 2 个比较:2x2 数组和 10x10 数组。通过这些,我大致了解了为函数调用支付的惩罚,特别是 np.dot 函数调用,以及 np.dot:

中 BLAS 优化的收益
print "2x2 Time Test:"
Arr = rand(500,201,2,2)
%timeit loopMult(Arr)
%timeit loopMultJit(Arr)
%timeit loopMultJit_2X2(Arr)
%timeit loopMultJit_dot(Arr)
%timeit loopMultJit_dotInternal(Arr)

print "10x10 Time Test:"
Arr = rand(500,201,10,10)
%timeit loopMult(Arr)
%timeit loopMultJit(Arr)
%timeit loopMultJit_dot(Arr)
%timeit loopMultJit_dotInternal(Arr)

产生:

2x2 Time Test:
10 loops, best of 3: 55.8 ms per loop  # einsum
10 loops, best of 3: 48.7 ms per loop  # np.dot
1000 loops, best of 3: 1.09 ms per loop  # 2x2
10 loops, best of 3: 28.3 ms per loop  # naive dot, separate function
100 loops, best of 3: 2.58 ms per loop  # naive dot internal

10x10 Time Test:
1 loop, best of 3: 499 ms per loop  # einsum
10 loops, best of 3: 91.3 ms per loop  # np.dot
10 loops, best of 3: 170 ms per loop  # naive dot, separate function
10 loops, best of 3: 161 ms per loop  # naive dot internal

我想带回家的信息是:

  • 如果您不使用 numba 或需要单行代码,einsum 是不错的选择,但对于矩阵乘法,有更快的选项
  • 如果您使用的是小矩阵,手动操作而不是调用单独的函数会更快
  • 对于大矩阵,发明 BLAS 是有原因的,事实上,在小至 10x10 的情况下,加速非常明显。