Execution time difference in matrix multiplication caused by parentheses
Given two one-dimensional NumPy arrays a and b:
import numpy as np
N = 100000
a = np.random.randn(N)
b = np.random.randn(N)
why is there such a large difference in execution time between the following two expressions?
# expression 1
c = a @ a * b @ b
# expression 2
c = (a @ a) * (b @ b)
Using the %timeit magic in Jupyter Notebook, I get the following results:
%timeit a @ a * b @ b
223 µs ± 6.97 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
and
%timeit (a @ a) * (b @ b)
17.4 µs ± 27.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
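To reproduce the comparison outside Jupyter, here is a minimal sketch using the standard-library timeit module instead of the %timeit magic. The repeat counts (5 repeats of 200 calls) and the helper names expr1 and expr2 are arbitrary choices for illustration; absolute timings will differ with hardware and the BLAS library NumPy is linked against.

```python
import timeit

import numpy as np

N = 100000
a = np.random.randn(N)
b = np.random.randn(N)


def expr1():
    # Expression 1 from the question, written without extra parentheses
    return a @ a * b @ b


def expr2():
    # Expression 2 from the question
    return (a @ a) * (b @ b)


for name, fn in [("a @ a * b @ b", expr1), ("(a @ a) * (b @ b)", expr2)]:
    # Best of 5 repeats, 200 calls each; divide to get seconds per call
    best = min(timeit.repeat(fn, number=200, repeat=5)) / 200
    print(f"{name:20s} {best * 1e6:8.1f} µs per call")
```

Both functions return the same value up to floating-point rounding; only the amount of work differs.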
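In both versions you compute two dot products of length-N vectors. The first version, however, additionally performs N multiplications (scaling the whole vector b by the scalar a @ a, which also creates a temporary array of length N), while the second needs only a single scalar multiplication.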
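Because @ and * have the same precedence in Python and are evaluated left to right, a @ a * b @ b is equivalent to ((a @ a) * b) @ b, that is: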
aa = a @ a # N multiplications and additions -> scalar
aab = aa * b # N multiplications -> vector
aabb = aab @ b # N multiplications and additions -> scalar
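The step-by-step evaluation above can be checked directly. This sketch reuses the variable names from the answer and shows that the intermediate aab is a full length-N array rather than a scalar, and that the explicit steps give the same value as the unparenthesized expression.

```python
import numpy as np

N = 100000
a = np.random.randn(N)
b = np.random.randn(N)

# Expression 1, evaluated in the order Python actually uses
aa = a @ a        # dot product -> scalar
aab = aa * b      # scalar times vector -> new array with N elements
aabb = aab @ b    # dot product -> scalar

print(aab.shape)                        # (100000,)
print(np.isclose(aabb, a @ a * b @ b))  # True
```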
In contrast, (a @ a) * (b @ b) is equivalent to:
aa = a @ a # N multiplications and additions -> scalar
bb = b @ b # N multiplications and additions -> scalar
aabb = aa * bb # 1 multiplication -> scalar
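A quick sanity check that the cheaper ordering computes the same quantity: mathematically both expressions equal (a·a)(b·b), so they should agree up to floating-point rounding. This is a sketch using the same variable names as above.

```python
import numpy as np

N = 100000
a = np.random.randn(N)
b = np.random.randn(N)

# Expression 2: two dot products, then a single scalar multiplication
aa = a @ a
bb = b @ b
aabb = aa * bb

# Same mathematical value as expression 1; rounding may differ in the last bits
print(np.isclose(aabb, a @ a * b @ b))  # True
```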
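It is well known that the cost of a chain of matrix multiplications depends on where the parentheses are placed. There are algorithms that exploit this fact to optimize matrix chain multiplication.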
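As an illustration of such an algorithm, here is a minimal sketch of the classic dynamic-programming solution to the matrix chain problem. The function name matrix_chain_order and its return format are hypothetical; cost is counted as scalar multiplications, assuming multiplying a p×q matrix by a q×r matrix costs p·q·r.

```python
def matrix_chain_order(dims):
    """Return (min_cost, parenthesization) for a chain of matrices.

    dims has length n + 1; matrix i (0-based) has shape dims[i] x dims[i + 1].
    Cost is the number of scalar multiplications (p * q * r per product).
    """
    n = len(dims) - 1
    # cost[i][j]: cheapest way to multiply matrices i..j (inclusive)
    cost = [[0] * n for _ in range(n)]
    # split[i][j]: index k where the optimal product of i..j is split
    split = [[0] * n for _ in range(n)]

    for length in range(2, n + 1):          # length of the sub-chain
        for i in range(n - length + 1):
            j = i + length - 1
            cost[i][j] = float("inf")
            for k in range(i, j):           # split between matrix k and k + 1
                c = (cost[i][k] + cost[k + 1][j]
                     + dims[i] * dims[k + 1] * dims[j + 1])
                if c < cost[i][j]:
                    cost[i][j] = c
                    split[i][j] = k

    def paren(i, j):
        # Rebuild the optimal parenthesization from the recorded splits
        if i == j:
            return f"A{i + 1}"
        k = split[i][j]
        return f"({paren(i, k)}{paren(k + 1, j)})"

    return cost[0][n - 1], paren(0, n - 1)


# The question's product viewed as a chain a^T (1xN), a (Nx1), b^T (1xN), b (Nx1)
N = 100000
print(matrix_chain_order([1, N, 1, N, 1]))  # (200001, '((A1A2)(A3A4))')
```

Treating the 1-D vectors as a 1×N row and an N×1 column is an assumption made for this illustration. Under it, the left-to-right order ((aᵀa)bᵀ)b costs 3N multiplications, while the optimal split (aᵀa)(bᵀb) costs 2N + 1, which matches the counts in the answer.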
Update: as I just learned, NumPy has a function that optimizes chains of matrix multiplications: numpy.linalg.multi_dot
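A small sketch of numpy.linalg.multi_dot on a chain where the order matters. The shapes are arbitrary choices for illustration: evaluating A @ B @ v left to right first builds a 1000×1000 matrix, whereas A @ (B @ v) only ever forms a length-10 vector. multi_dot takes a list of arrays and picks the parenthesization itself.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((1000, 10))
B = rng.standard_normal((10, 1000))
v = rng.standard_normal(1000)

left_to_right = A @ B @ v                  # builds a 1000 x 1000 intermediate
reordered = A @ (B @ v)                    # B @ v is only a length-10 vector
chained = np.linalg.multi_dot([A, B, v])   # NumPy chooses the order

print(np.allclose(left_to_right, reordered))  # True
print(np.allclose(left_to_right, chained))    # True
```

Note that multi_dot only accepts 1-D arrays as its first and last arguments (treated as row and column vectors); the arrays in between must be 2-D, so the question's a @ a * b @ b cannot be passed to it as four 1-D vectors directly.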