从 PyTorch 检索原始数据 nn.Embedding

Retrieving original data from PyTorch nn.Embedding

我正在将包含 5 个类别(例如汽车、公共汽车...)的数据框传递到 nn.Embedding

当我执行 embedding.parameters() 时,我可以看到有 5 个张量,但我如何知道哪个索引对应于原始输入(例如汽车、公共汽车...)?

你不能,因为张量是未命名的(只能命名维度,参见 PyTorch's Named Tensors)。 您必须将名称保存在单独的数据容器中,例如(此处为 4 类别):

import pandas as pd
import torch

df = pd.DataFrame(
    {
        "bus": [1.0, 2, 3, 4, 5],
        "car": [6.0, 7, 8, 9, 10],
        "bike": [11.0, 12, 13, 14, 15],
        "train": [16.0, 17, 18, 19, 20],
    }
)

df_data = df.to_numpy().T
df_names = list(df)

embedding = torch.nn.Embedding(df_data.shape[0], df_data.shape[1])
embedding.weight.data = torch.from_numpy(df_data)

现在您可以简单地将它与您想要的任何索引一起使用:

index = 1
embedding(torch.tensor(index)), df_names[index]

这会给你 (tensor[6, 7, 8, 9, 10], "car") 所以数据和相应的列名。