Pytorch lightning metrics: ValueError: preds and target must have same number of dimensions, or one additional dimension for preds

Pytorch lightning metrics: ValueError: preds and target must have same number of dimensions, or one additional dimension for preds

谷歌搜索让你无所适从,所以我决定通过将其发布为可搜索问题来帮助未来的我和其他人。


def __init__():
    ...
    self.val_acc = pl.metrics.Accuracy()

def validation_step(self, batch, batch_index):
    ...
    self.val_acc.update(log_probs, label_batch)

给予

ValueError: preds and target must have same number of dimensions, or one additional dimension for preds

log_probs.shape == (16, 4)label_batch.shape == (16, 4)

有什么问题?

pl.metrics.Accuracy() 需要一批 dtype=torch.long 标签,而不是单热编码标签。

因此,它应该被喂养

self.val_acc.update(log_probs, torch.argmax(label_batch.squeeze(), dim=1))


这与torch.nn.CrossEntropyLoss

相同