如何查看一个热编码的 class 名称?

how to see the class name of one hot encoded?

我有一个包含两列的 CSV 文件:'Text' 一条推文及其“标签”。每条推文可能属于以下 4 个类别之一:仇恨、中立、反仇恨和非亚裔侵略。 我通过 Python 中的以下代码对训练和测试向量进行了 One Hot Encode Y values:

encoder = LabelEncoder()
y_train = encoder.fit_transform(train['Label'].values)
y_train = to_categorical(y_train) 
y_test = encoder.fit_transform(test['Label'].values)
y_test = to_categorical(y_test)

如果你打印第一个索引:

print(y_train[0])

答案是:

[0. 1. 0. 0.]

我们知道,每个Label被转换为一个长度为4的向量,其中每个位置对应一个Labelclass。我怎样才能找到每个class的位置?

例如:仇恨=0,反仇恨=1,...

首先,考虑encoder class在训练集上拟合然后变换,但只变换测试集!我建议使用 inverse_transform 方法来检索您的原始标签。

from sklearn import preprocessing
le = preprocessing.LabelEncoder()
le.fit(['Hate', 'Neutral', 'CounterHate and Non-Asian', 'Aggression'])
print(list(le.classes_))
print(le.transform(['CounterHate and Non-Asian', 'Hate', 'Neutral']))
print(le.inverse_transform([2, 2, 1]))

输出:

['Aggression', 'CounterHate and Non-Asian', 'Hate', 'Neutral']
[1 2 3]
['Hate' 'Hate' 'CounterHate and Non-Asian']