找到前三个相关类别及其对应的概率

Finding the top three relevant category and its corresponding probabilities

从下面的脚本中,我找到了多 class 文本 class化问题中的最高概率及其对应的类别。如何在不使用循环的情况下以最有效的方式找到最高的前 3 个预测概率及其对应的类别。

probabilities = classifier.predict_proba(X_test)
max_probabilities = probabilities.max(axis=1)
order=np.argsort(probabilities, axis=1)
classification=(classifier.classes_[order[:, -1:]])
print(accuracy_score(classification,y_test))

提前致谢。 (我有大约 50 个类别,我想为每个叙述提取 50 个类别中最相关的前 3 个类别,并将它们显示在数据框中)

您已经完成了这里的大部分艰苦工作,只是缺少一些 numpy foo 来完成它。你的线路

order = np.argsort(probabilities, axis=1)

包含排序概率的索引,因此每个样本的 [[lowest_prob_class_1, ..., highest_prob_class_1]...]。你曾经用 order[:, -1:] 给你的 class 证明,即最高概率​​的索引 class。所以为了获得前三名 classes 我们可以做一个简单的改变

top_3_classes = classifier.classes_[order[:, -3:]]

然后得到相应的概率我们可以使用

top_3_probabilities = probabilities[np.repeat(np.arange(order.shape[0]), 3),
                                    order[:, -3:].flatten()].reshape(order.shape[0], 3)