''Boolean value of Tensor with more than one value is ambiguous'' 广播 torch Tensor 时

''Boolean value of Tensor with more than one value is ambiguous'' when broadcasting torch Tensor

我的objective是提取一个pytorch张量的维度,其索引不在给定列表中。我想使用广播来做到这一点,如下所示:

Sim = torch.rand((5, 5))
samples_idx = [0]  # the index of dim that I don't want to extract
a = torch.arange(Sim.size(0)) not in samples_idx
result = Sim[a]

我假设 a 是一个 True/Flase 的张量,维度为 5.But 我得到错误 RuntimeError: Boolean value of Tensor with more than one value is ambiguous。任何人都可以帮我指出哪里出了问题?谢谢

您可以通过从包含所有索引的集合中减去 samples_idx 来创建包含所需索引的集合:

>>> Sim = torch.rand(5, 5)
tensor([[0.9069, 0.3323, 0.8358, 0.3738, 0.3516],
        [0.1894, 0.5747, 0.0763, 0.8526, 0.2351],
        [0.0304, 0.7631, 0.3799, 0.9968, 0.6143],
        [0.0647, 0.2307, 0.4061, 0.9648, 0.0212],
        [0.8479, 0.6400, 0.0195, 0.2901, 0.4026]])

>>> samples_idx = [0]

以下基本上充当您的torch.arange not in sample_idx

>>> idx = set(range(len(Sim))) - set(samples_idx)
{1, 2, 3, 4}

然后用idx进行索引:

>>> Sim[tuple(idx),:]
tensor([[0.1894, 0.5747, 0.0763, 0.8526, 0.2351],
        [0.0304, 0.7631, 0.3799, 0.9968, 0.6143],
        [0.0647, 0.2307, 0.4061, 0.9648, 0.0212],
        [0.8479, 0.6400, 0.0195, 0.2901, 0.4026]])

“维度”和“指标”的概念存在误解。您想要的是过滤 Sim 并仅保留索引与给定规则匹配的行(第 0 维)。

以下是您可以这样做的方法:

Sim = torch.rand((5, 5))
samples_idx = [0]  # the index of dim that I don't want to extract
a = [v for v in range(Sim.size(0)) if v not in samples_idx]
result = Sim[a]

a 不是布尔张量,而是要保留的索引列表。然后您使用它在第 0 个维度(行)上索引 Sim

not in不是可以广播的操作,你应该使用常规的Python推导列表。

也许这有点偏离重点,但您也可以尝试使用布尔索引。

>>> Sim = torch.rand((5, 5))
tensor([[0.8128, 0.2024, 0.3673, 0.2038, 0.3549],
        [0.4652, 0.4304, 0.4987, 0.2378, 0.2803],
        [0.2227, 0.1466, 0.6736, 0.0929, 0.3635],
        [0.2218, 0.9078, 0.2633, 0.3935, 0.2199],
        [0.7007, 0.9650, 0.4192, 0.4781, 0.9864]])

>>> samples_idx = [0]
>>> a = torch.ones(Sim.size(0))
>>> a[samples_idx] = 0
>>> result = Sim[a.bool(), :]
tensor([[0.4652, 0.4304, 0.4987, 0.2378, 0.2803],
        [0.2227, 0.1466, 0.6736, 0.0929, 0.3635],
        [0.2218, 0.9078, 0.2633, 0.3935, 0.2199],
        [0.7007, 0.9650, 0.4192, 0.4781, 0.9864]])

这样您就不必遍历所有 samples_idx 列表来检查是否包含。