如何为具有布尔值的张量中的行应用条件

How to apply conditions for rows in a tensor where there is boolean values

我有以下张量:

predictions = torch.tensor([[ True, False, False],
                            [False, False,  True],
                            [False,  True,  True],
                            [ True, False, False]])

我像下面这样沿轴应用了条件。

new_pred= []

if predictions == ([True,False,False]):
       new_pred = torch.Tensor(0)
if predictions == ([False,False,True]):
       new_pred = torch.Tensor(2)
if predictions == ([False,True,True]):
       new_pred = torch.Tensor(2)

所以我希望最终输出 (new_pred) 是: 张量([0, 2, 2, 0])

但是 new_pred 张量的 [] 是空白的。我认为我的逻辑一定有缺陷,因为 new_pred 中没有存储任何内容。有人可以帮我准确地写出这个逻辑吗?

predictions的类型是torch.Tensor,而([True, False, False])是一个列表,首先要保证两边的类型相同。

predictions == torch.tensor([True,False,False])
>>> tensor([[ True, True, True],
            [False, True, False],
            [False, False, False],
            [True, True, True]])

然后,你仍然在比较 2d 张量和 1d 张量,这在 if 语句中是不明确的,解决这个问题的一个简单方法是写一个 for 循环,比较predictions 的每一行的条件和 append 结果到 new_pred 列表。请注意,您将比较两个大小为 3 的布尔张量,因此,您必须确保所有单元格的比较结果都是 True

predictions = torch.tensor([[ True, False, False],
                            [False, False,  True],
                            [False,  True,  True],
                            [ True, False, False]])

conditions = torch.tensor([[True,False,False], 
                            [False,False,True],
                            [False,True,True]])
new_predict = []
for index in range(predictions.size(0)):
    if (predictions[index] == conditions[0]).all():
        new_predict.append(0)
    # ...

或者,您可以使用切片来实现您的预​​期结果,而无需任何 for 循环。