numpy中的以下操作是什么意思?

What is the meaning of the following operation in numpy?

我正在挖掘一段 numpy 代码,有一行我根本不明白:

W[:, :, None] * h[None, :, :] * diff[:, None, :]

其中 Whdiff 是 784x20、20x100 和 784x100 矩阵。乘法结果是 784x20x100 数组,但我不知道这个计算实际上做了什么,结果的含义是什么。

值得一提的是,该行来自机器学习相关代码,W对应神经网络层的权重数组,h 是层激活,diff 是网络目标和假设之间的差异(来自 Sida Wang's thesis 关于转换自动编码器)。

对于 NumPy 数组,* 对应于逐元素乘法。为了使其工作,两个数组必须是:

  • 形状相同
  • 这样一个数组可以 broadcast 到另一个

如果在配对每个数组的尾随维度时,每对中的长度相等或其中一个长度为 1,则一个数组可以广播到另一个数组。

例如,以下数组 AB 具有适合广播的形状:

A.shape == (20, 1, 3)
B.shape ==     (4, 3)

(3 等于 3 然后 A 中的下一个长度是 1 可以与任何长度配对。没关系B 的维度小于 A。)

要使两个不兼容的数组可以相互广播,可以将额外的维度插入一个或两个数组中。使用 Nonenp.newaxis 索引维度会将长度为 1 的额外维度插​​入到数组中。


我们来看问题中的例子。 Python 从左到右计算重复的乘法:

  • W[:, :, None] 的形状为 (784, 20, 1)
  • h[None, :, :] 的形状为 ( 1, 20, 100)

根据上面的解释和乘法 returns 具有形状 (784, 20, 100).

的数组,这些形状是可广播的
  • 最后一次乘法的数组形状,(784, 20, 100)
  • diff[:, None, :] 的形状为 (784, 1, 100)

这两个数组的这些形状是兼容的,因此第二次乘法成功。返回一个形状为 (784, 20, 100) 的数组。