如何在 DecisionTreeClassifier 中为 multi-class 设置设置 class 权重

How to set class weights in DecisionTreeClassifier for multi-class setting

我正在使用 sklearn.tree.DecisionTreeClassifier 训练 3-class class化问题。

3class中的记录数如下:

A: 122038
B: 43626
C: 6678

当我训练 classifier 模型时,它无法学习 class - C。虽然效率达到 65-70% 但它完全忽略了 class C.

然后我开始了解 class_weight 参数,但我不确定如何在 multiclass 设置中使用它。

这是我的代码:(我使用了 balanced 但它的准确性更差)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
clf = tree.DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=1,class_weight='balanced')
clf = clf.fit(X_train,y_train)
y_pred = clf.predict(X_test)

如何使用与 class 分布成比例的权重。

其次,有没有更好的方法来解决这个不平衡 class 问题以提高准确性。?

您还可以将值字典传递给 class_weight 参数以设置您自己的权重。例如体重 class 你可以做的一半:

class_weight={
    'A': 0.5,
    'B': 1.0,
    'C': 1.0
}

通过 class_weight='balanced' 它会自动设置与 class 频率成反比的权重。

可以在 class_weight 参数下的文档中找到更多信息: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

通常预计平衡 classes 会降低准确性。这就是为什么准确性通常被认为是不平衡数据集的不良指标。

您可以尝试 sklearn 包含的平衡准确性指标作为开始,但还有许多其他潜在指标可以尝试,这取决于您的最终目标。

https://scikit-learn.org/stable/modules/model_evaluation.html

如果您不熟悉 'confusion matrix' 及其相关值(如精度和召回率),那么我将从那里开始您的研究。

https://en.wikipedia.org/wiki/Precision_and_recall

https://en.wikipedia.org/wiki/Confusion_matrix

https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

“平衡”模式是开始的方式。

The “balanced” mode uses the values of y to automatically adjust weights inversely proportional to class frequencies in the input data as n_samples / (n_classes * np.bincount(y))


要手动定义权重,您需要 字典字典列表,具体取决于问题。


class_weight dict, list of dict or “balanced”, default=None

Weights associated with classes in the form {class_label: weight}. If None, all classes are supposed to have weight one. For multi-output problems, a list of dicts can be provided in the same order as the columns of y.

Note that for multioutput (including multilabel) weights should be defined for > each class of every column in its own dict. For example, for four-class multilabel > classification weights should be [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of [{1:1}, {2:5}, {3:1}, {4:1}].


示例:

如果classA的频率是10%,classB的频率是90%:

clf = tree.DecisionTreeClassifier(class_weight={A:9,B:1})