仅对 Pytorch 中的某些值使用 torch.eq()

Use torch.eq() only for some value in Pytorch

有没有一种方法可以使用 torch.eq() 或类似的函数来计算基于元素的相等性但仅限于某些元素? 假设我需要知道两个张量中有多少个 1 相等,但我不关心其他数字。

知道怎么做吗?

假设我们有 2 个张量 AB 填充了随机元素,最后在某处填充了一些 1。张量 C 是你想要的结果:

A = torch.rand((2, 3, 3))
B = torch.rand((2, 3, 3))

# fill A and B with some 1s
...

C = (A == 1) * (B == 1)

使用以下张量我们得到:

(A) [[[ 0.6151,  1.0000,  0.6515],
         [ 0.3337,  0.4262,  0.0731],
         [ 0.4571,  0.2380,  1.0000]],

        [[ 1.0000,  0.1114,  0.8183],
         [ 0.9178,  1.0000,  1.0000],
         [ 0.8180,  0.8112,  0.2972]]]

(B) [[[ 0.4305,  1.0000,  0.5378],
         [ 0.4171,  0.4365,  0.2805],
         [ 0.1076,  0.1259,  0.9695]],

        [[ 1.0000,  0.0911,  1.0000],
         [ 0.6757,  0.5095,  0.4499],
         [ 0.5787,  1.0000,  1.0000]]]

(C) [[[ 0,  1,  0],
         [ 0,  0,  0],
         [ 0,  0,  0]],

        [[ 1,  0,  0],
         [ 0,  0,  0],
         [ 0,  0,  0]]]