在 Pytorch 中,如何使用 BoolTensor 掩码将张量切片到多个 dims 中?
In Pytorch how to slice tensor across multiple dims with BoolTensor masks?
我想在 Pytorch 中使用 BoolTensor 索引对多维张量进行切片。我期望索引张量保留索引为真的部分,而索引为假的部分被切掉。
我的代码是这样的
import torch
a = torch.zeros((5, 50, 5, 50))
tr_indices = torch.zeros((50), dtype=torch.bool)
tr_indices[1:50:2] = 1
val_indices = ~tr_indices
print(a[:, tr_indices].shape)
print(a[:, tr_indices, :, val_indices].shape)
我希望 a[:, tr_indices, :, val_indices]
的形状为 [5, 25, 5, 25]
,但它 returns [25, 5, 5]
。结果是
torch.Size([5, 25, 5, 50])
torch.Size([25, 5, 5])
我很困惑。谁能解释一下为什么?
PyTorch 继承了其高级索引行为。像这样切片两次应该可以达到你想要的输出:
a[:, tr_indices][..., val_indices]
我想在 Pytorch 中使用 BoolTensor 索引对多维张量进行切片。我期望索引张量保留索引为真的部分,而索引为假的部分被切掉。
我的代码是这样的
import torch
a = torch.zeros((5, 50, 5, 50))
tr_indices = torch.zeros((50), dtype=torch.bool)
tr_indices[1:50:2] = 1
val_indices = ~tr_indices
print(a[:, tr_indices].shape)
print(a[:, tr_indices, :, val_indices].shape)
我希望 a[:, tr_indices, :, val_indices]
的形状为 [5, 25, 5, 25]
,但它 returns [25, 5, 5]
。结果是
torch.Size([5, 25, 5, 50])
torch.Size([25, 5, 5])
我很困惑。谁能解释一下为什么?
PyTorch 继承了其高级索引行为
a[:, tr_indices][..., val_indices]