在 tidymodel 的 collect_predictions() 上使用 caret::confusionMatrix() 进行模型评估时出错
Getting error on using caret::confusionMatrix() on collect_predictions() from tidymodel for model evaluation
我是 R 的新手,使用 tidymodels
创建了一个 classification
模型,下面是 collect_predictions(model)
的结果
collect_predictions(members_final) %>% print()
# A tibble: 19,126 x 6
id .pred_died .pred_survived .row .pred_class died
<chr> <dbl> <dbl> <int> <fct> <fct>
1 train/test split 0.285 0.715 5 survived survived
2 train/test split 0.269 0.731 6 survived survived
3 train/test split 0.298 0.702 7 survived survived
4 train/test split 0.276 0.724 8 survived survived
5 train/test split 0.251 0.749 10 survived survived
6 train/test split 0.124 0.876 18 survived survived
7 train/test split 0.127 0.873 21 survived survived
8 train/test split 0.171 0.829 26 survived survived
9 train/test split 0.158 0.842 30 survived survived
10 train/test split 0.150 0.850 32 survived survived
# … with 19,116 more rows
它适用于 yardstick
函数:
collect_predictions(members_final) %>%
conf_mat(died, .pred_class)
Truth
Prediction died survived
died 196 7207
survived 90 11633
但是当我将 collect_predictions
传送到 caret::confusionMatrix()
时它不起作用
collect_predictions(members_final) %>%
caret::confusionMatrix(as.factor(died), as.factor(.pred_class))
############## output #################
Error: `data` and `reference` should be factors with the same levels.
Traceback:
1. collect_predictions(members_final) %>% caret::confusionMatrix(as.factor(died),
. as.factor(.pred_class))
2. withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
3. eval(quote(`_fseq`(`_lhs`)), env, env)
4. eval(quote(`_fseq`(`_lhs`)), env, env)
我不确定这里出了什么问题,所以如何修复它以使用插入符号评估?
使用插入符计算的目的是找出 positive/negative class.
是否有任何其他方法可以找出 positive/neg classes (levels(df$class) 是否正确找出正 classes 用于模型?)
如果您有预测,例如 collect_predictions()
的输出,那么您不想将其通过管道传输到插入符号的函数中。它不像 yardstick 函数那样将数据作为第一个参数。相反,将参数作为向量传递:
library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
data("two_class_example", package = "yardstick")
confusionMatrix(two_class_example$predicted, two_class_example$truth)
#> Confusion Matrix and Statistics
#>
#> Reference
#> Prediction Class1 Class2
#> Class1 227 50
#> Class2 31 192
#>
#> Accuracy : 0.838
#> 95% CI : (0.8027, 0.8692)
#> No Information Rate : 0.516
#> P-Value [Acc > NIR] : <2e-16
#>
#> Kappa : 0.6749
#>
#> Mcnemar's Test P-Value : 0.0455
#>
#> Sensitivity : 0.8798
#> Specificity : 0.7934
#> Pos Pred Value : 0.8195
#> Neg Pred Value : 0.8610
#> Prevalence : 0.5160
#> Detection Rate : 0.4540
#> Detection Prevalence : 0.5540
#> Balanced Accuracy : 0.8366
#>
#> 'Positive' Class : Class1
#>
由 reprex package (v0.3.0.9001)
于 2020-10-21 创建
看起来您的变量名称将是 died
和 .pred_class
;您需要将包含预测的数据框保存为对象才能访问它。
我是 R 的新手,使用 tidymodels
创建了一个 classification
模型,下面是 collect_predictions(model)
collect_predictions(members_final) %>% print()
# A tibble: 19,126 x 6
id .pred_died .pred_survived .row .pred_class died
<chr> <dbl> <dbl> <int> <fct> <fct>
1 train/test split 0.285 0.715 5 survived survived
2 train/test split 0.269 0.731 6 survived survived
3 train/test split 0.298 0.702 7 survived survived
4 train/test split 0.276 0.724 8 survived survived
5 train/test split 0.251 0.749 10 survived survived
6 train/test split 0.124 0.876 18 survived survived
7 train/test split 0.127 0.873 21 survived survived
8 train/test split 0.171 0.829 26 survived survived
9 train/test split 0.158 0.842 30 survived survived
10 train/test split 0.150 0.850 32 survived survived
# … with 19,116 more rows
它适用于 yardstick
函数:
collect_predictions(members_final) %>%
conf_mat(died, .pred_class)
Truth
Prediction died survived
died 196 7207
survived 90 11633
但是当我将 collect_predictions
传送到 caret::confusionMatrix()
时它不起作用
collect_predictions(members_final) %>%
caret::confusionMatrix(as.factor(died), as.factor(.pred_class))
############## output #################
Error: `data` and `reference` should be factors with the same levels.
Traceback:
1. collect_predictions(members_final) %>% caret::confusionMatrix(as.factor(died),
. as.factor(.pred_class))
2. withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
3. eval(quote(`_fseq`(`_lhs`)), env, env)
4. eval(quote(`_fseq`(`_lhs`)), env, env)
我不确定这里出了什么问题,所以如何修复它以使用插入符号评估?
使用插入符计算的目的是找出 positive/negative class.
是否有任何其他方法可以找出 positive/neg classes (levels(df$class) 是否正确找出正 classes 用于模型?)
如果您有预测,例如 collect_predictions()
的输出,那么您不想将其通过管道传输到插入符号的函数中。它不像 yardstick 函数那样将数据作为第一个参数。相反,将参数作为向量传递:
library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
data("two_class_example", package = "yardstick")
confusionMatrix(two_class_example$predicted, two_class_example$truth)
#> Confusion Matrix and Statistics
#>
#> Reference
#> Prediction Class1 Class2
#> Class1 227 50
#> Class2 31 192
#>
#> Accuracy : 0.838
#> 95% CI : (0.8027, 0.8692)
#> No Information Rate : 0.516
#> P-Value [Acc > NIR] : <2e-16
#>
#> Kappa : 0.6749
#>
#> Mcnemar's Test P-Value : 0.0455
#>
#> Sensitivity : 0.8798
#> Specificity : 0.7934
#> Pos Pred Value : 0.8195
#> Neg Pred Value : 0.8610
#> Prevalence : 0.5160
#> Detection Rate : 0.4540
#> Detection Prevalence : 0.5540
#> Balanced Accuracy : 0.8366
#>
#> 'Positive' Class : Class1
#>
由 reprex package (v0.3.0.9001)
于 2020-10-21 创建看起来您的变量名称将是 died
和 .pred_class
;您需要将包含预测的数据框保存为对象才能访问它。