GEMM 使用 Numpy einsum

GEMM using Numpy einsum

单个 numpy einsum 语句能否复制 gemm 功能?标量和矩阵乘法看起来很简单,但我还没有找到如何让“+”工作。如果它更简单, D = alpha * A * B + beta * C 是可以接受的(实际上更可取)

alpha = 2
beta = 3
A = np.arange(9).reshape(3, 3)
B = A + 1
C = B + 1

left_part = alpha*np.dot(A, B)
print(left_part)
left_part = np.einsum(',ij,jk->ik', alpha, A, B)
print(left_part)

这里似乎有些混乱:np.einsum 处理可以按以下形式转换的操作:broadcast–multiply–reduce。按元素求和不属于其范围。

之所以需要这种乘法运算,是因为写出这些运算 "naively" 可能会很快超出内存或计算资源。例如,考虑矩阵乘法:

import numpy as np
x, y = np.ones((2, 2000, 2000))

# explicit loop - ridiculously slow
a = sum(x[:,j,np.newaxis] * y[j,:] for j in range(2000))

# explicit broadcast-multiply-reduce: throws MemoryError
a = (x[:,:,np.newaxis] * y[:,np.newaxis,:]).sum(1)

# einsum or dot: fast and memory-saving
a = np.einsum('ij,jk->ik', x, y)

爱因斯坦约定 因式分解 加法,所以你 可以将类似 BLAS 的问题简单地写成:

d = np.einsum(',ij,jk->ik', alpha, a, b) + np.einsum(',ik', beta, c)

具有最小的内存开销(如果您真的关心内存,您可以将其中的大部分重写为就地操作)和恒定的运行时开销(两次 python-to-C 调用的成本)。

因此,关于性能,恭敬地,这对我来说似乎是一个过早优化的案例:您是否真的验证过将类似 GEMM 的操作拆分为两个单独的 numpy 调用是您代码中的瓶颈?如果确实如此,那么我建议如下(按照参与程度增加的顺序):

  1. 仔细尝试!,scipy.linalg.blas.dgemm。如果你得到我会感到惊讶 明显更好的性能,因为 dgemms 通常只有 积木本身。

  2. 试试表达式编译器(本质上你是在提议 这样的事情)比如 Theano.

  3. 使用 Cython 或 C 编写您自己的 generalised ufunc