在 tidymodel 的 collect_predictions() 上使用 caret::confusionMatrix() 进行模型评估时出错

Getting error on using caret::confusionMatrix() on collect_predictions() from tidymodel for model evaluation

我是 R 的新手,使用 tidymodels 创建了一个 classification 模型,下面是 collect_predictions(model)

的结果
collect_predictions(members_final) %>% print()

# A tibble: 19,126 x 6
   id               .pred_died .pred_survived  .row .pred_class died    
   <chr>                 <dbl>          <dbl> <int> <fct>       <fct>   
 1 train/test split      0.285          0.715     5 survived    survived
 2 train/test split      0.269          0.731     6 survived    survived
 3 train/test split      0.298          0.702     7 survived    survived
 4 train/test split      0.276          0.724     8 survived    survived
 5 train/test split      0.251          0.749    10 survived    survived
 6 train/test split      0.124          0.876    18 survived    survived
 7 train/test split      0.127          0.873    21 survived    survived
 8 train/test split      0.171          0.829    26 survived    survived
 9 train/test split      0.158          0.842    30 survived    survived
10 train/test split      0.150          0.850    32 survived    survived
# … with 19,116 more rows

它适用于 yardstick 函数:

collect_predictions(members_final) %>%
  conf_mat(died, .pred_class)

          Truth
Prediction  died survived
  died       196     7207
  survived    90    11633

但是当我将 collect_predictions 传送到 caret::confusionMatrix() 时它不起作用

collect_predictions(members_final) %>% 
  caret::confusionMatrix(as.factor(died), as.factor(.pred_class))

############## output #################
Error: `data` and `reference` should be factors with the same levels.
Traceback:

1. collect_predictions(members_final) %>% caret::confusionMatrix(as.factor(died), 
 .     as.factor(.pred_class))

2. withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))

3. eval(quote(`_fseq`(`_lhs`)), env, env)

4. eval(quote(`_fseq`(`_lhs`)), env, env)

我不确定这里出了什么问题,所以如何修复它以使用插入符号评估?

使用插入符计算的目的是找出 positive/negative class.

是否有任何其他方法可以找出 positive/neg classes (levels(df$class) 是否正确找出正 classes 用于模型?)

如果您有预测,例如 collect_predictions() 的输出,那么您不想将其通过管道传输到插入符号的函数中。它不像 yardstick 函数那样将数据作为第一个参数。相反,将参数作为向量传递:

library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
data("two_class_example", package = "yardstick")

confusionMatrix(two_class_example$predicted, two_class_example$truth)
#> Confusion Matrix and Statistics
#> 
#>           Reference
#> Prediction Class1 Class2
#>     Class1    227     50
#>     Class2     31    192
#>                                           
#>                Accuracy : 0.838           
#>                  95% CI : (0.8027, 0.8692)
#>     No Information Rate : 0.516           
#>     P-Value [Acc > NIR] : <2e-16          
#>                                           
#>                   Kappa : 0.6749          
#>                                           
#>  Mcnemar's Test P-Value : 0.0455          
#>                                           
#>             Sensitivity : 0.8798          
#>             Specificity : 0.7934          
#>          Pos Pred Value : 0.8195          
#>          Neg Pred Value : 0.8610          
#>              Prevalence : 0.5160          
#>          Detection Rate : 0.4540          
#>    Detection Prevalence : 0.5540          
#>       Balanced Accuracy : 0.8366          
#>                                           
#>        'Positive' Class : Class1          
#> 

reprex package (v0.3.0.9001)

于 2020-10-21 创建

看起来您的变量名称将是 died.pred_class;您需要将包含预测的数据框保存为对象才能访问它。