PyTorch nn.module 不会取消批处理操作

PyTorch nn.module won't unbatch operations

我有一个 nn.Module,其 forward 函数接受两个输入。在函数内部,我将其中一个输入 x1 乘以一组可训练参数,然后将它们与另一个输入 x2.

连接起来
class ConcatMe(nn.Module):
    def __init__(self, pad_len, emb_size):
        super(ConcatMe, self).__init__()
        self.W = nn.Parameter(torch.randn(pad_len, emb_size).to(DEVICE), requires_grad=True)
        self.emb_size = emb_size
     
    def forward(self, x1: Tensor, x2: Tensor):
        cat = self.W * torch.reshape(x2, (1, -1, 1))
        return torch.cat((x1, cat), dim=-1)

根据我的理解,人们应该能够在 PyTorch 的 nn.Module 中编写操作,就像我们对批大小为 1 的输入所做的那样。出于某种原因,情况并非如此。我收到一个错误,表明 PyTorch 仍在考虑 batch_size.

x1 =  torch.randn(100,2,512)
x2 = torch.randint(10, (2,1))
concat = ConcatMe(100, 512)
concat(x1, x2)

-----------------------------------------------------------------------------------
File "/home/my/file/path.py, line 0, in forward
    cat = self.W * torch.reshape(x2, (1, -1, 1))
RuntimeError: The size of tensor a (100) must match the size of tensor b (2) at non-singleton dimension 1

我制作了一个 for 循环来修补问题,如下所示:

class ConcatMe(nn.Module):
    def __init__(self, pad_len, emb_size):
        super(ConcatMe, self).__init__()
        self.W = nn.Parameter(torch.randn(pad_len, emb_size).to(DEVICE), requires_grad=True)
        self.emb_size = emb_size
     
    def forward(self, x1: Tensor, x2: Tensor):
        batch_size = x2.shape[0]
        cat = torch.ones(x1.shape).to(DEVICE)

        for i in range(batch_size):
            cat[:, i, :] = self.W * x2[i]

        return torch.cat((x1, cat), dim=-1)

但我觉得有更优雅的解决方案。这与我在 nn.Module 中创建参数有关吗?如果是这样,我可以实施什么不需要 for 循环的解决方案。

From my understanding, one is supposed to be able to write operations in PyTorch's nn.Modules like we would for inputs with a batch size of 1.

我不确定你从哪里得到这个假设,这绝对是 不是 正确的 - 相反:你总是需要以他们可以处理的方式编写它们任意批量维度的一般情况。

从你的第二个实现来看,你似乎在尝试将两个维度不兼容的张量相乘。所以为了解决这个问题,你必须定义

        self.W = torch.nn.Parameter(torch.randn(pad_len, 1, emb_size), requires_grad=True)

要更好地理解此类事情,了解 broadcasting

会有所帮助