使用多个维度处理 nn.Linear
Working of nn.Linear with multiple dimensions
PyTorch 的 nn.Linear(in_features, out_features)
接受大小为 (N_batch, N_1, N_2, ..., N_end)
的张量,其中 N_end = in_features
。输出是一个大小为 (N_batch, N_1, N_2, ..., out_features)
.
的张量
我不太清楚它在以下情况下的表现:
- 如果
v
是一行,输出将为A^Tv+b
- 如果
M
是一个矩阵,则将其视为一批行,对每一行v
,执行A^Tv+b,然后将所有内容放回矩阵表格
- 如果输入张量的阶数更高呢?假设输入张量的维度为
(N_batch, 4, 5, 6, 7)
。该层是否会输出一批大小为 N_batch
的 (1, 1, 1, N_out)
形状的向量,所有形状都变成 (N_batch, 4, 5, 6, N_out)
张量?
对于 1 维,输入是具有暗淡 in_features
的向量,输出是 out_features
。按照你说的计算
对于 2 维,输入是 N_batch
个暗淡 in_features
的向量,输出是 N_batch
个暗暗 out_features
的向量。按照你说的计算
对于 3 个维度,输入是 (N_batch, C, in_features)
,也就是 N_batch
个矩阵,每个矩阵有 C
行带有 dim in_features
的向量,输出是 N_batch
矩阵,每个矩阵有 C
行带有暗淡 out_features
.
的向量
如果你觉得很难想到3维的情况。一种简单的方法是将形状展平为 (N_batch * C, in_features)
,这样输入就变成 N_batch * C
行带有暗淡 in_features
的向量,这与二维情况相同。这个flatten部分不涉及计算,只是重新排列输入。
所以在你的情况 3 中,它的输出是 N_batch
个 (3, 4, 5, 6, N_out)
个向量,或者在用 dim N_out
重新排列它的 N_batch * 3 * 4 * 5 * 6
个向量之后。你所有 1 个暗淡的形状都不正确,因为总共只有 N_batch * N_out
个元素。
如果你深入研究 pytorch 的内部 C 实现,你会发现 matmul
实现实际上压平了维度,正如我所描述的 native matmul 这是 [=32= 使用的确切函数]
PyTorch 的 nn.Linear(in_features, out_features)
接受大小为 (N_batch, N_1, N_2, ..., N_end)
的张量,其中 N_end = in_features
。输出是一个大小为 (N_batch, N_1, N_2, ..., out_features)
.
我不太清楚它在以下情况下的表现:
- 如果
v
是一行,输出将为A^Tv+b - 如果
M
是一个矩阵,则将其视为一批行,对每一行v
,执行A^Tv+b,然后将所有内容放回矩阵表格 - 如果输入张量的阶数更高呢?假设输入张量的维度为
(N_batch, 4, 5, 6, 7)
。该层是否会输出一批大小为N_batch
的(1, 1, 1, N_out)
形状的向量,所有形状都变成(N_batch, 4, 5, 6, N_out)
张量?
对于 1 维,输入是具有暗淡 in_features
的向量,输出是 out_features
。按照你说的计算
对于 2 维,输入是 N_batch
个暗淡 in_features
的向量,输出是 N_batch
个暗暗 out_features
的向量。按照你说的计算
对于 3 个维度,输入是 (N_batch, C, in_features)
,也就是 N_batch
个矩阵,每个矩阵有 C
行带有 dim in_features
的向量,输出是 N_batch
矩阵,每个矩阵有 C
行带有暗淡 out_features
.
如果你觉得很难想到3维的情况。一种简单的方法是将形状展平为 (N_batch * C, in_features)
,这样输入就变成 N_batch * C
行带有暗淡 in_features
的向量,这与二维情况相同。这个flatten部分不涉及计算,只是重新排列输入。
所以在你的情况 3 中,它的输出是 N_batch
个 (3, 4, 5, 6, N_out)
个向量,或者在用 dim N_out
重新排列它的 N_batch * 3 * 4 * 5 * 6
个向量之后。你所有 1 个暗淡的形状都不正确,因为总共只有 N_batch * N_out
个元素。
如果你深入研究 pytorch 的内部 C 实现,你会发现 matmul
实现实际上压平了维度,正如我所描述的 native matmul 这是 [=32= 使用的确切函数]