conv1d pytorch 如何对一系列字符或帧进行操作?

How conv1d pytorch operates on a sequence of characters or frames?

我理解应用于图像的卷积滤波器(例如,具有 3 个输入通道的 224x224 图像通过 56 个 5x5 conv 的总滤波器转换为具有 56 个输出通道的 224x224 图像)。关键是有 56 个不同的过滤器,每个过滤器具有 5x5x3 的权重,最终产生输出图像 224x224, 56(逗号后的术语是输出通道)。

但我似乎无法理解 conv1d 过滤器如何在字符序列的 seq2seq 模型中工作。我正在查看的模型之一 https://arxiv.org/pdf/1712.05884.pdf 有一个“post-net 层由 512 个形状为 5×1 的滤波器组成” 在频谱图上运行frame 80-d(表示frame中有80个不同的float值),filter的结果是512-d frame.

在pytorch中打印的图层显示为:

(conv): Conv1d(80, 512, kernel_size=(5,), stride=(1,), padding=(2,))

这一层的参数显示为:

postnet.convolutions.0.0.conv.weight : 512x80x5 = 204800

介绍说明

基本上 Conv1d 就像 Conv2d 但不是 "sliding" 横跨图像的矩形 window (比如 3x3 代表 kernel_size=3 ) 你 "slide" 穿过向量(比如长度 256),kernel(比如大小 3)。这是 in_channelsout_channels 等于 1 的情况,这是基本的。

下面你可以看到 Conv1d 滑过 3 in_channels (x-axis, y-axis, z-axis) 滑过 seconds 步骤。

您可以向内核添加深度(就像您对 2D5x5x3 立方体的卷积所做的那样),这也是 5x35 是内核大小,3in_channels的个数)。现在可能有 out_channels 个正方形(例如 56 out_channels),所以最终生成的序列是 56 x sequence_length.

问题

[...] post-net layer is comprised of 512 filters with shape 5×1" that operates on a spectrogram frame 80-d (means 80 different float values in the frame), and the result of filter is a 512-d frame.

所以你的输入是80d(而不是上面的3轴),kernel_size是相同的(5),out_channels512。所以输入可能看起来像这样:[64, 80, 256](对于 [batch, in_channels, length]),输出将是 [64, 512, 256](假设两边都使用了 3 的填充)。

I don't understand what in_channels, out_channels mean in pytorch conv1d definition as in images I can easily understand what in-channels/out-channels mean, but for sequence of 80-float values frames I'm at loss. What do they mean in the context of seq2seq model like this above?

我猜上面已经回答了。要点是:序列不是 80 浮点值! 序列可以是任意长度(就像将图像传递给卷积时图像可以是任意大小一样),这里 in_channels80.

How do 512, 5x1 filters on 80 float values produce 512 float values?**

512 x sequence_length 值在 80 x sequence_length 输入上产生。

Shouldn't the weights in this layer instead be 512*5*1 as it only has 512 filters each of which is 5x1?

在 PyTorch 中,在您的情况下,权重的形状为 torch.Size([512, 80, 5])。如果您有一个输入通道,它们可能是 torch.Size([512, 1, 5]),但在这种情况下,它们有 80 个。