How to train a non-binary classification rpart model with F1 as the metric instead of accuracy?

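I am using caret for my non-binary (three-class) decision tree classification. My dataset is skewed, so I would like to use F1 instead of accuracy for training and testing. How can I set this up?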

For an MWE, let's predict cut in the diamonds dataset:

library(ggplot2)  # provides the diamonds dataset
library(caret)
# 75/25 split, stratified on the outcome
inTrain <- createDataPartition(diamonds$cut, p = 0.75, list = FALSE)
training <- diamonds[inTrain, ]
testing <- diamonds[-inTrain, ]
fitModel <- train(cut ~ ., data = training, method = "rpart")

How do I use F1 here?

The page at http://topepo.github.io/caret/training.html explains in detail how to create a new metric for the train function.

You need to create a new function with three arguments:

  • data - "is a reference for a data frame or matrix with columns called obs and pred for the observed and predicted outcome values (either numeric data for regression or character values for classification)"
  • lev - "is a character string that has the outcome factor levels taken from the training data. For regression, a value of NULL is passed into the function."
  • model - "is a character string for the model being used"

The function should compute the F-score from the observed and predicted labels in the data object, and name the result after the metric.

For example, here is a function that computes accuracy:

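summaryStats <- function(data, lev = NULL, model = NULL) {
  # data$obs and data$pred hold the observed and predicted classes
  correct <- sum(data$pred == data$obs)
  incorrect <- sum(data$pred != data$obs)
  out <- correct / (correct + incorrect)
  names(out) <- "acc"  # must match the metric argument passed to train()
  out
}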

Then create a new trainControl object and train your model:

fitControl <- trainControl(summaryFunction = summaryStats)
fitModel <- train(cut ~ ., data = training, method = "rpart", trControl = fitControl, metric = "acc", maximize = TRUE)
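
The same pattern as summaryStats can be adapted to F1. With more than two classes there are several ways to combine the per-class scores, so the following is only a minimal sketch under stated assumptions: it uses macro-averaging (the unweighted mean of the per-class F1 values), and it counts a class whose F1 is undefined (0/0) as 0. The function name macroF1Summary and the toy data are hypothetical and are not part of caret.

```r
# Macro-averaged F1 as a caret-style summary function (a sketch).
# `data` is expected to have columns `obs` and `pred`; `lev` holds the class levels.
macroF1Summary <- function(data, lev = NULL, model = NULL) {
  if (is.null(lev)) lev <- levels(factor(data$obs))
  obs  <- factor(data$obs,  levels = lev)
  pred <- factor(data$pred, levels = lev)

  # Confusion table: rows = predicted class, columns = observed class
  tab <- table(pred, obs)
  tp  <- diag(tab)

  precision <- tp / rowSums(tab)  # NaN when a class is never predicted
  recall    <- tp / colSums(tab)  # NaN when a class never occurs in obs
  f1 <- 2 * precision * recall / (precision + recall)

  # Assumption: an undefined per-class F1 (0/0) is counted as 0
  f1[!is.finite(f1)] <- 0

  # The name "F1" is what would be passed as `metric` to train()
  c(F1 = mean(f1))
}

# Toy check on made-up labels for three classes
toyLevels <- c("a", "b", "c")
toy <- data.frame(
  obs  = factor(c("a", "a", "b", "b", "c", "c"), levels = toyLevels),
  pred = factor(c("a", "b", "b", "b", "c", "a"), levels = toyLevels)
)
toyScore <- macroF1Summary(toy, lev = toyLevels)
print(toyScore)
```

To use it for training, pass it the same way as above: trainControl(summaryFunction = macroF1Summary), then call train with metric = "F1" and maximize = TRUE. For the test set, the same function can be applied to a data frame whose obs column is testing$cut and whose pred column is predict(fitModel, testing). If the rare classes matter more than the common ones, a weighted average of the per-class scores may be a better fit than the plain mean used here.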