将 one-hot 编码维度转换为 1 位置的索引

Convert one-hot encoded dimension into the index of position of 1

我有一个三维张量[batch_size, sequence_length, number_of_tokens]。 最后一个维度是单热编码的。我想接收一个二维张量,其中 sequence_lengthnumber_of_tokens 维度的“1”的索引位置组成。

例如,将形状为 (2, 3, 4):

的张量
[[[0, 1, 0, 0]
[1, 0, 0, 0]
[0, 0, 0, 1]]
[[1, 0, 0, 0]
[1, 0, 0, 0]
[0, 0, 1, 0]]]

变成形状为 (2, 3) 的张量,其中 number_of_tokens 维度被转换为 1 的位置:

[[1, 0, 3]
[0, 0, 2]]

我这样做是为了准备模型结果,以便在计算损失时与参考答案进行比较,我希望这是正确的方法。

你可以通过连续的列表理解来做你想做的事:

x=[[[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 0, 1]],
[[1, 0, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0]]]

y=[[ell2.index(1) for ell2 in ell1] for ell1 in x]

print(y) # prints [[1, 0, 3], [0, 0, 2]]

迭代主张量的元素和每个元素,returns该元素的分量中的“1”索引列表。

只需执行:

res = x.argmax(axis = 2)

如果你的原始张量是中指定的,你可以绕过one-hot编码直接使用argmax:

t = torch.rand(2, 3, 4)
t = t.argmax(dim=2)