使用整数打印 PyTorch 张量时如何设置精度?

How can I set precision when printing a PyTorch tensor with integers?

我有:

mask = mask_model(input_spectrogram)
mask = torch.round(mask).float()
torch.set_printoptions(precision=4)
print(mask.size(), input_spectrogram.size(), masked_input.size())
print(mask[0][-40:], mask[0].size())

这会打印:

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.], grad_fn=<SliceBackward>) torch.Size([400])

但我希望它打印 1.0000 而不是 1.。为什么我的 set_precision 不这样做?即使我转换为 float()?

不幸的是,这就是 PyTorch 显示张量值的方式。你的代码工作正常,如果你这样做 print(mask * 1.1),你可以看到当张量值不能再表示为整数时,PyTorch 确实打印出 4 个十进制值。