如何在 Torch 中有效地对 3 维输入向量应用线性变换?

How to effectively apply linear transform on 3-dimensional input vector in Torch?

假设我们有一个 DoubleTensor - size: 5x32x3000,我们想将其转换为 DoubleTensor - size: 5x32x100 以进一步输入。现在,我要做的是:

local seq = nn.Sequential()
seq:add(nn.SplitTable(1))
seq:add(nn.MapTable():add(nn.Linear(3000,100)))
seq:add(nn.JoinTable(1)):add(nn.View(5,32,100))

这看起来有点复杂,我觉得应该有更高效的方法。你能想出更好的解决方案吗?

我已经试过了,它会输出你想要的大小 (5, 32, 1000)

data = torch.Tensor(5, 32, 3000)
mul = torch.Tensor(3000, 1000)
res = torch.mm(data:view(5*32, 3000), mul):view(5, 32, 1000)
print(res:size())

另一种方式也可以是:

seq = nn.Sequential()
seq:add(nn.SplitTable(1)):add(nn.MapTable():add(nn.Linear(3000,100)))
seq:add(nn.JoinTable(1))