在pytorch中是否有一个函数可以找到布尔矩阵中每一行的或?

Is there a function that finds the OR of every row in a boolean matrix in pytorch?

我有一个大小为 n x m 的矩阵 A,所有条目都是布尔值。我希望我的所有计算都在 GPU 上进行,并且我将矩阵 A 存储为张量,每个条目都是 pytorch 的 bool 数据类型。我希望单个向量 b 的输出是一个 1 x m 张量,它存储 A 中所有行的或。

我想要的:
矩阵 =
[a1,1, a1,2, , a1,3, ... ,a1,m]
[a2,1, a2,2, , a2,3, ... ,a2,m]
...
[an,1, an,2, , an,3, ... ,an,m]


b = [b1, b2, , b3, , . .., bm]

s.t。 bi = a1,i | a2,i | a3,i | ... | an,i
其中 |是Pytorch中的OR运算符

本质上我想要一个应用行或列布尔运算的函数。我知道 | .__OR__ 可以使用,Pytorch 的 OR 函数将两个布尔张量作为输入,我需要遍历所有行以获得我想要的 b 向量。

因为 OR 是可交换的 (a|(b|c) = (a|b)|c),我认为 pytorch 会有一些不错的函数可以通过执行 | 来加速它。并行操作或以某种分而治之的方法进行操作,而不是使用循环来执行此操作。欢迎任何想法或参考,以加快使用 pytorch 应用交际 row/column 明智布尔运算的过程。最好所有操作都在 GPU 上完成。

torch.any and torch.all

两者都采用 dim 参数,因此您可以计算 or/and 行。