如何使用 PyTorch 沿特定维度进行一次热编码?

How do I one hot encode along a specific dimension using PyTorch?

我有一个大小为 [3, 15, 136] 的张量,其中:

我想使用 tokens 维 (136) 中的概率来单热我的张量。为此,我想提取序列长度中每个字母的标记维度,并将 1 置于最大可能性并将所有其他标记标记为 0.

你可以使用 PyTorch 的 one_hot 函数来实现:

import torch.nn.functional as F

t = torch.rand(3, 15, 136)

F.one_hot(t.argmax(dim=2), 136)