class_weight 随机森林中的超参数改变了混淆矩阵中的样本数量

class_weight hyperparameter in Random Forest change the amounts of samples in confusion matrix

我目前正在研究一个随机森林分类模型,该模型包含 24,000 个样本,其中 20,000 个属于 class 0,4,000 个属于 class 1。我做了一个 train_test_split,其中 test_set 是整个数据集的 0.2test_set 中大约有 4,800 个样本)。由于我正在处理不平衡数据,我查看了旨在解决此问题的超参数 class_weight

我在设置 class_weight='balanced' 并查看训练集的 confusion_matrix 时遇到的问题我得到了类似的东西:

array([[13209, 747], [ 2776, 2468]])

如您所见,下方数组对应 False Negative = 2776 后跟 True Positive = 2468,而上方数组对应 True Negative = 13209 后跟 False Positive = 747。问题是根据 confusion_matrix 属于 class 1 的样本数量是 2,776 (False Negative) + 2,468 (True Positive),总计 5,244 samples 属于 class 1。这没有任何意义,因为整个数据集仅包含 4,000 个属于 class 1 的样本,其中只有 3,200 个样本在 train_set 中。它看起来像 confusion_matrix return 矩阵的 Transposed 版本,因为 training_set 中属于 class 1 的实际样本量应该总和为 3,200 train_set 中的样本和 test_set 中的 800 个样本。一般来说,正确的数字应该是 747 + 2468 加起来是 3,215,这是属于 class 1 的正确数量的样本。 有人可以解释一下我使用 class_weight 时发生了什么吗? confusion_matrix return 是矩阵的 transposed 版本是真的吗?我看错了吗? 我曾尝试寻找答案并访问了几个在某种程度上相似的问题,但其中 none 确实涵盖了这个问题。

这些是我查看的一些来源:

https://datascience.stackexchange.com/questions/11564/how-does-class-weights-work-in-randomforestclassifier

https://stats.stackexchange.com/questions/244630/difference-between-sample-weight-and-class-weight-randomforest-classifier

using sample_weight and class_weight in imbalanced dataset with RandomForest Classifier

如有任何帮助,我们将不胜感激。

docs复制玩具示例:

from sklearn.metrics import confusion_matrix

y_true = [0, 1, 0, 1]
y_pred = [1, 1, 1, 0]

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
(tn, fp, fn, tp)
# (0, 2, 1, 1)

因此,您提供的混淆矩阵的读数似乎是正确的。

Is it true that the confusion_matrix returns a transposed version of the matrix?

如上例所示,不。但是一个非常容易(并且看似无辜)的错误可能是您交换了 y_truey_pred 参数的 order,这很重要;结果确实是一个转置矩阵:

# correct order of arguments:
confusion_matrix(y_true, y_pred)
# array([[0, 2],
#        [1, 1]])

# inverted (wrong) order of the arguments:
confusion_matrix(y_pred, y_true)
# array([[0, 1],
#        [2, 1]])

根据您提供的信息无法判断这是否是原因,这很好地提醒了您为什么应始终提供实际代码,而不是口头描述您的内容 想想 你的代码正在做...