如何将 3D 火炬张量切片为 2D 切片

How to slice 3D torch tensor into 2D slices

我正在处理 3D CT 医学数据,我正在尝试将其切成可输入 UNet 模型的 2D 切片。

我已经将数据加载到 torch 数据加载器中,当前每次迭代都会生成一个 4D 张量:

for batch_index, batch_samples in enumerate(train_loader):
    data, target = batch_samples['image'].float().cuda(), batch_samples['label'].float().cuda()
    print(data.size())
torch.Size([1, 333, 512, 512])
torch.Size([1, 356, 512, 512])

比如这张。我想遍历 333 个切片,然后是 356 个切片,这样模型每次都会收到手电筒大小 [1, 1, 512, 512]。

我希望是这样的:

for x in (data[:,x,:,:]):

可以,但它说我需要先定义 x。如何迭代 torch 张量中的特定维度?

只需指定维度:

for i in range(data.shape[1]):  # dim=1
    x = data[:, i, :, :]
    # [...]

如果您确实需要额外的维度,只需添加 .unsqueeze():

d = 1
for i in range(data.shape[d]):         # dim=1
    x = data[:, i, :, :].unsqueeze(d)  # same dim=1
    # [...]