在 Pytorch 中,如何使用 BoolTensor 掩码将张量切片到多个 dims 中?

In Pytorch how to slice tensor across multiple dims with BoolTensor masks?

我想在 Pytorch 中使用 BoolTensor 索引对多维张量进行切片。我期望索引张量保留索引为真的部分,而索引为假的部分被切掉。

我的代码是这样的

import torch
a = torch.zeros((5, 50, 5, 50))

tr_indices = torch.zeros((50), dtype=torch.bool)
tr_indices[1:50:2] = 1
val_indices = ~tr_indices

print(a[:, tr_indices].shape)
print(a[:, tr_indices, :, val_indices].shape)

我希望 a[:, tr_indices, :, val_indices] 的形状为 [5, 25, 5, 25],但它 returns [25, 5, 5]。结果是

torch.Size([5, 25, 5, 50])
torch.Size([25, 5, 5])

我很困惑。谁能解释一下为什么?

PyTorch 继承了其高级索引行为。像这样切片两次应该可以达到你想要的输出:

a[:, tr_indices][..., val_indices]