Tidymodels: Classify as TRUE only if the probability is 75% or higher

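I have a binary classification problem, which I am tackling with a random forest and a logistic regression. Based on the results of conf_mat(), collect_metrics() and collect_predictions(), I would like to change my models so that they only classify an observation as TRUE when the model is "sure", i.e. when the predicted probability is 75% or higher. I just don't know where to specify this change. If anyone could give me a hint, that would be great. My gut tells me it belongs somewhere in the model specification, for example somewhere in here, but maybe I'm wrong.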

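library(tidymodels)

# Random forest: mtry and min_n are left to be tuned
canc_rf_model <- rand_forest(
    mtry = tune(),
    min_n = tune(),
    trees = 500) %>%
  set_engine("ranger") %>%
  set_mode("classification")

# Plain logistic regression (no tuning parameters)
canc_log_model <- logistic_reg() %>% 
  set_engine("glm") %>% 
  set_mode("classification")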

Thank you very much! M.

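The hard class predictions come from the underlying ranger::predictions() function, not from the tidymodels functions, so there isn't much you can change in the fit itself.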
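For two classes, a hard class prediction amounts to picking the class with the larger probability, which is the same as an implicit 0.5 cutoff on either class. The base-R sketch below illustrates that rule on made-up probabilities; it does not claim that ranger or parsnip literally call max.col() internally, and the values in probs are hypothetical.

```r
# Hypothetical class probabilities for four observations (each row sums to 1)
probs <- cbind(Impaired = c(0.90, 0.60, 0.40, 0.76),
               Control  = c(0.10, 0.40, 0.60, 0.24))

# Hard class = the column with the largest probability in each row
hard_class <- factor(colnames(probs)[max.col(probs, ties.method = "first")],
                     levels = colnames(probs))
hard_class

# With two classes this is the same question as: is P(Impaired) >= 0.5?
identical(hard_class == "Impaired", probs[, "Impaired"] >= 0.5)
```

So changing the decision rule means replacing that 0.5 with another cutoff after the probabilities have been computed, not changing the model specification.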

If you like, though, you can change this quite smoothly after fitting. Here is an example classification model:

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

data("ad_data")
alz <- ad_data

# data splitting
set.seed(100)
alz_split  <- initial_split(alz, strata = Class, prop = .9)
alz_train  <- training(alz_split)
alz_test   <- testing(alz_split)

# data resampling
set.seed(100)
alz_folds <- 
    vfold_cv(alz_train, v = 10, strata = Class)

rf_mod <-
    rand_forest(trees = 1e3) %>% 
    set_engine("ranger") %>% 
    set_mode("classification")

rf_wf <-
    workflow() %>% 
    add_formula(Class ~ .) %>% 
    add_model(rf_mod)

set.seed(100)
rf_preds <- rf_wf %>% 
    fit_resamples(
        resamples = alz_folds, 
        control = control_resamples(save_pred = TRUE)) %>% 
    collect_predictions()

Here is the default confusion matrix:

rf_preds %>%
    conf_mat(Class, .pred_class)
#>           Truth
#> Prediction Impaired Control
#>   Impaired       37       5
#>   Control        45     213
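Reading this matrix column by column helps when picking a cutoff: each column total is the number of truly Impaired or truly Control observations, which no threshold can change, so a threshold only moves observations between the two rows. The base-R sketch below re-enters the counts printed above and computes sensitivity and specificity, assuming Impaired (the first factor level) is treated as the event.

```r
# Default confusion matrix from above (rows = Prediction, columns = Truth);
# matrix() fills column by column
cm <- matrix(c(37, 45, 5, 213), nrow = 2,
             dimnames = list(Prediction = c("Impaired", "Control"),
                             Truth      = c("Impaired", "Control")))

# Column totals (82 Impaired, 218 Control) are fixed by the data
colSums(cm)

# Share of truly Impaired cases that are caught, and of Control cases kept as Control
sens <- cm["Impaired", "Impaired"] / sum(cm[, "Impaired"])  # 37 / 82
spec <- cm["Control", "Control"] / sum(cm[, "Control"])     # 213 / 218
round(c(sensitivity = sens, specificity = spec), 3)
```

Raising the cutoff for Impaired can only move counts from the Impaired row to the Control row, so sensitivity can only stay the same or fall while specificity can only stay the same or rise.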

You can use the probably package to post-process your class probability estimates and override the default hard class predictions:

library(probably)
#> 
#> Attaching package: 'probably'
#> The following objects are masked from 'package:base':
#> 
#>     as.factor, as.ordered

rf_preds %>%
    mutate(.pred_class = make_two_class_pred(.pred_Impaired, 
                                             levels(rf_preds$Class),
                                             threshold = 0.75),
           .pred_class = factor(.pred_class, levels = levels(rf_preds$Class))) %>%
    conf_mat(Class, .pred_class)
#>           Truth
#> Prediction Impaired Control
#>   Impaired        0       0
#>   Control        82     218

Created on 2021-03-23 by the reprex package (v1.0.0)
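Applied to the question's setting (a TRUE/FALSE outcome and a logistic regression), the same idea can be sketched in plain base R without tidymodels. The helper threshold_class() is hypothetical: it only mimics the rule make_two_class_pred() is assumed to apply (predict the first level when its probability is at least threshold). The built-in mtcars data stands in for real data, with manual transmission (am == 1) playing the role of TRUE.

```r
# Hypothetical helper: predict lvls[1] only when its probability >= threshold,
# otherwise lvls[2] (assumed to mirror probably::make_two_class_pred())
threshold_class <- function(prob_first, lvls, threshold = 0.5) {
  factor(ifelse(prob_first >= threshold, lvls[1], lvls[2]), levels = lvls)
}

# Stand-in outcome with "TRUE" as the first level, i.e. the class being thresholded
truth <- factor(ifelse(mtcars$am == 1, "TRUE", "FALSE"), levels = c("TRUE", "FALSE"))

# Logistic regression; type = "response" gives P(am == 1), i.e. P(TRUE)
fit <- glm(am ~ wt, data = mtcars, family = binomial)
p_true <- predict(fit, type = "response")

pred_default <- threshold_class(p_true, levels(truth), threshold = 0.5)
pred_strict  <- threshold_class(p_true, levels(truth), threshold = 0.75)

# Confusion matrices in the same layout as conf_mat(): rows = prediction
table(Prediction = pred_default, Truth = truth)
table(Prediction = pred_strict,  Truth = truth)
```

Because the rule thresholds the first level, the factor level order decides which class needs 75% certainty; here the levels are ordered c("TRUE", "FALSE") on purpose. yardstick likewise treats the first level as the event by default.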