Pytorch is throwing an error RuntimeError: result type Float can't be cast to the desired output type Long

Pytorch is throwing an error RuntimeError: result type Float can't be cast to the desired output type Long

我应该如何摆脱以下错误?

>>> t = torch.tensor([[1, 0, 1, 1]]).T
>>> p = torch.rand(4,1)
>>> torch.nn.BCEWithLogitsLoss()(p, t)

以上代码抛出以下错误:

RuntimeError: 结果类型 Float 无法转换为所需的输出类型 Long

BCEWithLogitsLoss 要求它的目标是 float 张量,而不是 long。所以你应该通过 dtype=torch.float32:

指定 t 张量的类型
import torch

t = torch.tensor([[1, 0, 1, 1]], dtype=torch.float32).T
p = torch.rand(4,1)
loss_fn = torch.nn.BCEWithLogitsLoss()

print(loss_fn(p, t))

输出:

tensor(0.5207)