NumPy/PyTorch 使用来自 argmax 调用的索引数组进行切片

NumPy/PyTorch slicing using an array of indices from an argmax call

我有以下tensors / ndarrays我正在操作。

a_intents 形状 (b, n_i) - 位置 ij 处的标量是批处理中 i-th 元素的意图 j 的激活。

形状 (b, n_i, d_m)

u_intents - 索引 ij 处维度 d_m 的矢量是 [=19= 的意图 j 的姿势矢量] 批次中的元素。

我想获得具有最大激活标量的意图索引,我这样做了:

 max_activations = argmax(a_intents, dim=-1, keepdim=False)

现在,使用这些索引,我想在 u_intents 中获得相应的向量。

max_activation_poses = u_intents[?, ?,:]

如何使用 max_activations 仅指示 dim 1 上的那些索引?我的直觉告诉我,如果我这样做,我将以不正确的形状结束

[:, max_activations, :]

我想要获得的形状是 (b, d_m) - 对于批处理中的每个元素,与 highest activation 具有相同索引的向量。

谢谢

如果您将 u_intents 向量视为二维向量,并将每个 argmax 索引偏移其批次索引乘以元素数,就很容易了。

# dummy values for demonstration
>>> b, n_i, d_m = 2, 3, 5
>>> a_intents = torch.rand(b, n_i)
>>> a_intents
tensor([[0.1733, 0.9965, 0.4790],
        [0.6056, 0.4040, 0.0236]])
>>> u_intents = torch.rand(b, n_i, d_m)
>>> u_intents
tensor([[[0.3301, 0.8153, 0.1356, 0.6623, 0.4385],
         [0.1902, 0.1748, 0.4131, 0.3887, 0.5363],
         [0.1211, 0.5773, 0.2405, 0.6313, 0.2064]],
        [[0.2592, 0.5127, 0.7301, 0.8883, 0.5665],
         [0.2767, 0.6545, 0.7595, 0.2677, 0.5163],
         [0.8158, 0.4940, 0.0492, 0.0911, 0.0465]]])
# add to each index the batch start
>>> max_activations = a_intents.argmax(dim=-1) + torch.arange(0, b*n_i, step=n_i)
# elements 1 and 3 of 0..5
>>> max_activations
tensor([1, 3])   
>>> poses = u_intents.view(b*n_i, d_m).index_select(0, max_activations)
# tensor of shape (b, d_m) correctly indexing the maxima.
>>> poses
tensor([[0.1902, 0.1748, 0.4131, 0.3887, 0.5363],
        [0.2592, 0.5127, 0.7301, 0.8883, 0.5665]])