使用 pyTorch 张量沿一个特定维度与 3 维张量进行索引
Indexing using pyTorch tensors along one specific dimension with 3 dimensional tensor
我有 2 个张量:
A 具有形状(批次、序列、词汇)
和 B 的形状为 (batch, sequence)。
A = torch.tensor([[[ 1., 2., 3.],
[ 5., 6., 7.]],
[[ 9., 10., 11.],
[13., 14., 15.]]])
B = torch.tensor([[0, 2],
[1, 0]])
我想获得以下信息:
C = torch.zeros_like(B)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
C[i,j] = A[i,j,B[i,j]]
但是以矢量化的方式。我尝试了 torch.gather 和其他东西,但我无法让它工作。
谁能帮帮我?
>>> import torch
>>> A = torch.tensor([[[ 1., 2., 3.],
... [ 5., 6., 7.]],
...
... [[ 9., 10., 11.],
... [13., 14., 15.]]])
>>> B = torch.tensor([[0, 2],
... [1, 0]])
>>> A.shape
torch.Size([2, 2, 3])
>>> B.shape
torch.Size([2, 2])
>>> C = torch.zeros_like(B)
>>> for i in range(B.shape[0]):
... for j in range(B.shape[1]):
... C[i,j] = A[i,j,B[i,j]]
...
>>> C
tensor([[ 1, 7],
[10, 13]])
>>> torch.gather(A, -1, B.unsqueeze(-1))
tensor([[[ 1.],
[ 7.]],
[[10.],
[13.]]])
>>> torch.gather(A, -1, B.unsqueeze(-1)).shape
torch.Size([2, 2, 1])
>>> torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1)
tensor([[ 1., 7.],
[10., 13.]])
您好,您可以使用 torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1)
。
A
和 B.unsqueeze(-1)
之间的第一个 -1
表示您要拾取元素的维度。
B.unsqueeze(-1) 中的第二个 -1
是在 B 上加一个 dim,使两个张量具有相同的 dims,否则你会得到 RuntimeError: Index tensor must have the same number of dimensions as input tensor
.
最后的-1
是将结果从torch.Size([2, 2, 1])
重塑为torch.Size([2, 2])
我有 2 个张量:
A 具有形状(批次、序列、词汇) 和 B 的形状为 (batch, sequence)。
A = torch.tensor([[[ 1., 2., 3.],
[ 5., 6., 7.]],
[[ 9., 10., 11.],
[13., 14., 15.]]])
B = torch.tensor([[0, 2],
[1, 0]])
我想获得以下信息:
C = torch.zeros_like(B)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
C[i,j] = A[i,j,B[i,j]]
但是以矢量化的方式。我尝试了 torch.gather 和其他东西,但我无法让它工作。 谁能帮帮我?
>>> import torch
>>> A = torch.tensor([[[ 1., 2., 3.],
... [ 5., 6., 7.]],
...
... [[ 9., 10., 11.],
... [13., 14., 15.]]])
>>> B = torch.tensor([[0, 2],
... [1, 0]])
>>> A.shape
torch.Size([2, 2, 3])
>>> B.shape
torch.Size([2, 2])
>>> C = torch.zeros_like(B)
>>> for i in range(B.shape[0]):
... for j in range(B.shape[1]):
... C[i,j] = A[i,j,B[i,j]]
...
>>> C
tensor([[ 1, 7],
[10, 13]])
>>> torch.gather(A, -1, B.unsqueeze(-1))
tensor([[[ 1.],
[ 7.]],
[[10.],
[13.]]])
>>> torch.gather(A, -1, B.unsqueeze(-1)).shape
torch.Size([2, 2, 1])
>>> torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1)
tensor([[ 1., 7.],
[10., 13.]])
您好,您可以使用 torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1)
。
A
和 B.unsqueeze(-1)
之间的第一个 -1
表示您要拾取元素的维度。
B.unsqueeze(-1) 中的第二个 -1
是在 B 上加一个 dim,使两个张量具有相同的 dims,否则你会得到 RuntimeError: Index tensor must have the same number of dimensions as input tensor
.
最后的-1
是将结果从torch.Size([2, 2, 1])
重塑为torch.Size([2, 2])