在 PyTorch 中使用 3D 张量索引对 4D 张量进行切片
Slicing a 4D tensor with a 3D tensor-index in PyTorch
我有一个 4D 张量(它恰好是一个 三批 56x56 图像的堆栈,其中每批有 16 张图像)大小为 [16, 3, 56, 56]。我的目标是 select 每个像素的那三个批次中正确的一个(我的索引图的大小为 [16, 56, 56])并得到我想要的图像。
现在,我想 select 这三个批次中的特定批次图像,其值如
[[[ 0, 0, 2, ..., 0, 0, 0],
[ 0, 0, 2, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
...,
[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 2, 0, ..., 0, 0, 0],
[ 0, 2, 2, ..., 0, 0, 0]],
[[ 0, 2, 0, ..., 1, 1, 0],
[ 0, 2, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 2, 0],
...,
[ 0, 0, 0, ..., 0, 2, 0],
[ 0, 0, 2, ..., 0, 2, 0],
[ 0, 0, 2, ..., 0, 0, 0]]]
因此,对于 0,值将从第一批 select 编辑,其中 1 和 2 表示我想要 select 第二批和第三批的值。
以下是指数的一些可视化效果,每种颜色代表另一批次。
我试图转置 4D 张量以匹配我的索引的维度,但它没有用。它所做的只是给我一份我尝试过的维度的副本select。均值
tposed = torch.transpose(fourD, 0,1) print(indices.size(),
outs.size(), tposed[:, indices].size())
产出
torch.Size([16, 56, 56]) torch.Size([16, 3, 56, 56]) torch.Size([3, 16, 56, 56, 56, 56])
而我需要的形状是
torch.Size([16, 56, 56]) or torch.Size([16, 1, 56, 56])
举个例子,如果我尝试 select 只为批次中的第一张图像设置正确的值
fourD[0,indices].size()
我的形状像
torch.Size([16, 56, 56, 56, 56])
更不用说我在整个张量上尝试此操作时出现内存不足错误。
对于使用这些索引select我图像中每个像素的这三个批次中的任何一个。
,我很感激任何帮助
注:
我试过这个选项
outs[indices[:,None,:,:]].size()
还有那个returns
torch.Size([16, 1, 56, 56, 3, 56, 56])
编辑:torch.take 没有多大帮助,因为它将输入张量视为一维数组。
原来 PyTorch 中有一个函数具有我正在搜索的功能。
torch.gather(fourD, 1, indices.unsqueeze(1))
完成任务。
很好地解释了 gather 的作用。
我有一个 4D 张量(它恰好是一个 三批 56x56 图像的堆栈,其中每批有 16 张图像)大小为 [16, 3, 56, 56]。我的目标是 select 每个像素的那三个批次中正确的一个(我的索引图的大小为 [16, 56, 56])并得到我想要的图像。
现在,我想 select 这三个批次中的特定批次图像,其值如
[[[ 0, 0, 2, ..., 0, 0, 0],
[ 0, 0, 2, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
...,
[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 2, 0, ..., 0, 0, 0],
[ 0, 2, 2, ..., 0, 0, 0]],
[[ 0, 2, 0, ..., 1, 1, 0],
[ 0, 2, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 2, 0],
...,
[ 0, 0, 0, ..., 0, 2, 0],
[ 0, 0, 2, ..., 0, 2, 0],
[ 0, 0, 2, ..., 0, 0, 0]]]
因此,对于 0,值将从第一批 select 编辑,其中 1 和 2 表示我想要 select 第二批和第三批的值。
以下是指数的一些可视化效果,每种颜色代表另一批次。
我试图转置 4D 张量以匹配我的索引的维度,但它没有用。它所做的只是给我一份我尝试过的维度的副本select。均值
tposed = torch.transpose(fourD, 0,1) print(indices.size(),
outs.size(), tposed[:, indices].size())
产出
torch.Size([16, 56, 56]) torch.Size([16, 3, 56, 56]) torch.Size([3, 16, 56, 56, 56, 56])
而我需要的形状是
torch.Size([16, 56, 56]) or torch.Size([16, 1, 56, 56])
举个例子,如果我尝试 select 只为批次中的第一张图像设置正确的值
fourD[0,indices].size()
我的形状像
torch.Size([16, 56, 56, 56, 56])
更不用说我在整个张量上尝试此操作时出现内存不足错误。
对于使用这些索引select我图像中每个像素的这三个批次中的任何一个。
,我很感激任何帮助注:
我试过这个选项
outs[indices[:,None,:,:]].size()
还有那个returns
torch.Size([16, 1, 56, 56, 3, 56, 56])
编辑:torch.take 没有多大帮助,因为它将输入张量视为一维数组。
原来 PyTorch 中有一个函数具有我正在搜索的功能。
torch.gather(fourD, 1, indices.unsqueeze(1))
完成任务。