nn.Linear层在pytorch中在附加维度上的应用

Application of nn.Linear layer in pytorch on additional dimentions

pytorch中的全连接层(nn.Linear)是如何应用在"additional dimensions"上的? documentation 说,它可以用于将张量 (N,*,in_features) 连接到 (N,*,out_features),其中 N 在批处理中的示例数量中,因此它是无关紧要的,并且* 是那些 "additional" 维度。这是否意味着单个层是使用附加维度中所有可能的切片进行训练的,或者是为每个切片训练的单独层还是其他不同的东西?

linear.weight 中学习了 in_features * out_features 个参数,在 linear.bias 中学习了 out_features 个参数。您可以将 nn.Linear 视为

  1. 将张量重塑为某些 (N', in_features),其中 N'N* 描述的所有维度的乘积:input_2d = input.reshape(-1, in_features)
  2. 应用标准矩阵-矩阵乘法output_2d = linear.weight @ input_2d
  3. 添加偏差 output_2d += linear.bias.reshape(1, in_features)(注意我们在所有 N' 维度上传播它)
  4. 重塑输出以具有与 input 相同的维度,除了最后一个:output = output_2d.reshape(*input.shape[:-1], out_features)
  5. return output

因此,前导维度 N* 维度的处理方式相同。该文档使 N 明确地让您知道输入必须 至少 2d,但可以是任意多维的。