批处理张量切片,切片 B x N x M 和 B x 1

batched tensor slice, slice B x N x M with B x 1

我有一个 B x M x N 张量 X,还有一个 B x 1 张量 Y,它对应于我要保留的维度 = 1 的张量 X 的索引。这个切片的 shorthand 是什么,这样我就可以避免循环?

基本上我想这样做:

Z = torch.zeros(B,N)

for i in range(B):
    Z[i] = X[i][Y[i]]

下面的代码与循环中的代码类似。不同之处在于,我们不是按顺序索引数组 ZXY,而是使用数组 i

并行索引它们
B, M, N = 13, 7, 19

X = np.random.randint(100, size= [B,M,N])
Y = np.random.randint(M  , size= [B,1])
Z = np.random.randint(100, size= [B,N])

i = np.arange(B)
Y = Y.ravel()    # reducing array to rank-1, for easy indexing

Z[i] = X[i,Y[i],:]

这段代码可以进一步简化为

-> Z[i] = X[i,Y[i],:]
-> Z[i] = X[i,Y[i]]
-> Z[i] = X[i,Y]
-> Z    = X[i,Y]

pytorch 等效代码

B, M, N = 5, 7, 3

X = torch.randint(100, size= [B,M,N])
Y = torch.randint(M  , size= [B,1])
Z = torch.randint(100, size= [B,N])

i = torch.arange(B)
Y = Y.ravel()

Z = X[i,Y]

答案由@Hammad is short and perfect for the job. Here's an alternative solution if you're interested in using some less known Pytorch built-ins. We will use torch.gather (similarly you can achieve this with numpy.take)提供。

torch.gather 背后的想法是构建一个新的张量——基于两个相同形状的张量,包含索引(这里 ~ Y)和值(这里 ~ X) .

执行的操作是Z[i][j][k] = X[i][Y[i][j][k]][k].

由于 X 的形状是 (B, M, N)Y 的形状是 (B, 1) 我们希望填补 Y 中的空白,这样 Y的形状变成了(B, 1, N).

这可以通过一些轴操作来实现:

>>> Y.expand(-1, N)[:, None] # expand to dim=1 to N and unsqueeze dim=1

torch.gather 的实际调用将是:

>>> X.gather(dim=1, index=Y.expand(-1, N)[:, None])

您可以通过添加 [:, 0].

将其重塑为 (B, N)

此功能在棘手的场景中非常有效...