具有 3d 输入的 Pytorch 交叉熵损失

Pytorch crossentropy loss with 3d input

我有一个输出大小为 (batch_size, max_len, num_classes) 的 3D 张量的网络。我的真值是 (batch_size, max_len) 的形状。如果我确实对标签执行单热编码,它将是 (batch_size, max_len, num_classes) 的形状,即 max_len 中的值是 [0, num_classes] 范围内的整数。由于原代码太长,我写了一个更简单的版本来复现原来的错误。

criterion = nn.CrossEntropyLoss()
batch_size = 32
max_len = 350
num_classes = 1000
pred = torch.randn([batch_size, max_len, num_classes])
label = torch.randint(0, num_classes,[batch_size, max_len])
pred = nn.Softmax(dim = 2)(pred)
criterion(pred, label)

pred和label的shape分别是torch.Size([32, 350, 1000])torch.Size([32, 350])

遇到的错误是

ValueError: Expected target size (32, 1000), got torch.Size([32, 350, 1000])

如果我一次性编码标签来计算损失

x = nn.functional.one_hot(label)
criterion(pred, x)

它会抛出以下错误

ValueError: Expected target size (32, 1000), got torch.Size([32, 350, 1000])

Pytorch documentation 开始,CrossEntropyLoss 期望其输入的形状为 (N, C, ...),因此第二维始终为 类 的数字。如果您将 preds 重塑为 (batch_size, num_classes, max_len).

大小,您的代码应该可以工作