PyTorch:比较三个张量?

PyTorch: compare three tensors?

我有三个布尔掩码张量,我想创建一个布尔掩码,如果值在三个张量中匹配,则它是 1,否则 0.

我试了torch.where(A == B == C, 1, 0),但是好像不支持这个

您可以使用:

((A == B) & (B == C))

如果需要,您始终可以将布尔张量转换为适当的类型:

((A == B) & (B == C)).to(float)

据我所知,张量基本上是绑定到设备的 NumPy 数组。如果对您的应用程序来说不是太贵并且您可以负担得起 CPU,您可以简单地将它转换为 NumPy 并通过比较执行您需要的操作。

torch.eq运算符only supports binary tensor comparisons,因此需要进行两次比较:

(A==B) & (B==C)