R中的不平衡数据集、分类树和成本矩阵
Unbalanced dataset, Classification tree and cost matrix in R
我正在尝试创建一个分类模型来预测以下两个 类 之一:"Hit" 或 "Miss"。
数据集包含大约 80% "Hits" 因此它是高度不平衡的,因此分类树(ctree from party package)等模型选择将所有结果预测为 "Hit" 并获得 80% 的准确率.
我尝试过欠采样和 SMOTE 算法但没有成功。
当模型将 "Miss" 分类为 "Hit" 时,如何更改成本矩阵以惩罚模型?
您可以使用 ctree
的 weights
参数来做到这一点。既然你不提供任何数据,我就用假数据来说明。
library(party)
## Some bogus data
set.seed(42)
class = factor(sample(1:2, 500, replace=TRUE, prob=c(0.8, 0.2)) )
x1 = rnorm(500)
x2 = rnorm(500, 0.7, 0.9)
x = ifelse(class == 1, x1, x2)
y1 = rnorm(500)
y2 = rnorm(500, 0.7, 0.9)
y = ifelse(class == 1, y1, y2)
Imbalanced = data.frame(x,y,class)
只需对该数据使用 ctree
就可以 class 将所有数据验证为 class 1.
CT1 = ctree(class ~ ., data=Imbalanced)
table(predict(CT1))
1 2
500 0
但是如果你设置权重,你可以让它找到更多的class2数据。
W = ifelse(class==1, 1, 2)
CT2 = ctree(class ~ ., data=Imbalanced, weights=W)
table(predict(CT2), class)
class
1 2
1 336 44
2 63 57
请注意,总体准确度下降了 ,但我们得到了更多的 class 2 点正确 class 化。如果你使用一个非常大的权重因子,你可以获得几乎所有的 class 2 点(以更大的整体精度损失为代价)。
W = ifelse(class==1, 1, 5)
CT3 = ctree(class ~ ., data=Imbalanced, weights=W)
table(predict(CT3), class)
class
1 2
1 178 4
2 221 97
我正在尝试创建一个分类模型来预测以下两个 类 之一:"Hit" 或 "Miss"。
数据集包含大约 80% "Hits" 因此它是高度不平衡的,因此分类树(ctree from party package)等模型选择将所有结果预测为 "Hit" 并获得 80% 的准确率.
我尝试过欠采样和 SMOTE 算法但没有成功。
当模型将 "Miss" 分类为 "Hit" 时,如何更改成本矩阵以惩罚模型?
您可以使用 ctree
的 weights
参数来做到这一点。既然你不提供任何数据,我就用假数据来说明。
library(party)
## Some bogus data
set.seed(42)
class = factor(sample(1:2, 500, replace=TRUE, prob=c(0.8, 0.2)) )
x1 = rnorm(500)
x2 = rnorm(500, 0.7, 0.9)
x = ifelse(class == 1, x1, x2)
y1 = rnorm(500)
y2 = rnorm(500, 0.7, 0.9)
y = ifelse(class == 1, y1, y2)
Imbalanced = data.frame(x,y,class)
只需对该数据使用 ctree
就可以 class 将所有数据验证为 class 1.
CT1 = ctree(class ~ ., data=Imbalanced)
table(predict(CT1))
1 2
500 0
但是如果你设置权重,你可以让它找到更多的class2数据。
W = ifelse(class==1, 1, 2)
CT2 = ctree(class ~ ., data=Imbalanced, weights=W)
table(predict(CT2), class)
class
1 2
1 336 44
2 63 57
请注意,总体准确度下降了 ,但我们得到了更多的 class 2 点正确 class 化。如果你使用一个非常大的权重因子,你可以获得几乎所有的 class 2 点(以更大的整体精度损失为代价)。
W = ifelse(class==1, 1, 5)
CT3 = ctree(class ~ ., data=Imbalanced, weights=W)
table(predict(CT3), class)
class
1 2
1 178 4
2 221 97