在 PyTorch 中使用 3D 张量索引对 4D 张量进行切片

Slicing a 4D tensor with a 3D tensor-index in PyTorch

我有一个 4D 张量(它恰好是一个 三批 56x56 图像的堆栈,其中每批有 16 张图像)大小为 [16, 3, 56, 56]。我的目标是 select 每个像素的那三个批次中正确的一个(我的索引图的大小为 [16, 56, 56])并得到我想要的图像。

现在,我想 select 这三个批次中的特定批次图像,其值如

       [[[ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  0,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  0,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  2,  2,  ...,  0,  0,  0]],

        [[ 0,  2,  0,  ...,  1,  1,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  2,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  0,  0]]]

因此,对于 0,值将从第一批 select 编辑,其中 1 和 2 表示我想要 select 第二批和第三批的值。

以下是指数的一些可视化效果,每种颜色代表另一批次。

我试图转置 4D 张量以匹配我的索引的维度,但它没有用。它所做的只是给我一份我尝试过的维度的副本select。均值

tposed = torch.transpose(fourD, 0,1) print(indices.size(),
outs.size(), tposed[:, indices].size())

产出

torch.Size([16, 56, 56]) torch.Size([16, 3, 56, 56]) torch.Size([3, 16, 56, 56, 56, 56])

而我需要的形状是

torch.Size([16, 56, 56]) or torch.Size([16, 1, 56, 56])

举个例子,如果我尝试 select 只为批次中的第一张图像设置正确的值

fourD[0,indices].size()

我的形状像

torch.Size([16, 56, 56, 56, 56])

更不用说我在整个张量上尝试此操作时出现内存不足错误。

对于使用这些索引select我图像中每个像素的这三个批次中的任何一个

,我很感激任何帮助

注:

我试过这个选项

outs[indices[:,None,:,:]].size()

还有那个returns

torch.Size([16, 1, 56, 56, 3, 56, 56])

编辑:torch.take 没有多大帮助,因为它将输入张量视为一维数组。

原来 PyTorch 中有一个函数具有我正在搜索的功能。

torch.gather(fourD, 1, indices.unsqueeze(1)) 

完成任务。

很好地解释了 gather 的作用。