numpy中的批量矩阵乘法

Batch matrix multiplication in numpy

我有两个 numpy 数组 ab,形状分别为 [5, 5, 5][5, 5]。对于 ab,形状中的第一个条目是批量大小。当我执行矩阵乘法选项时,我得到一个形状为 [5, 5, 5] 的数组。一个MWE如下。

import numpy as np

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = a @ b
# c.shape is (5, 5, 5)

假设我要 运行 对批量大小进行循环,即 a[0] @ b[0].T,它会产生一个形状为 [5, 1] 的数组。最后,如果我沿轴 1 连接所有结果,我将得到一个形状为 [5, 5] 的结果数组。下面的代码更好地描述了这些行。

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = []
for i in range(5):
    c.append(a[i] @ b[i].T)
c = np.concatenate([d[:, None] for d in c], axis=1).T
# c.shape evaluates to be (5, 5)

不使用循环是否可以实现上述功能?例如,PyTorch 提供了一个名为 torch.bmm 的函数来计算它。谢谢。

b添加一个额外的维度以使矩阵乘法batch兼容并通过squeezing移除最后一个多余的维度:

c = np.matmul(a, b[:, :, None]).squeeze(-1)

或等价地:

c = (a @ b[:, :, None]).squeeze(-1)

在您的示例中,通过将 b 重塑为 5 x 5 x 1,两者都使 ab 的矩阵乘法变得合适。

您可以使用 numpy einsum 解决这个问题。

c = np.einsum('BNi,Bi ->BN', a, b)

Pytorch 也提供了这个 einsum 函数,语法略有变化。所以你可以很容易地解决它。它也可以轻松处理其他形状。

那么您就不必担心移调或压缩操作了。它还可以节省内存,因为不会在内部创建现有矩阵的副本。