Tidymodels:根据结果过滤工作流集

Tidymodels: Filter workflowsets based on results

有没有一种简便的方法来过滤 workflowsets 对象?在我的例子中,我只想保留 roc_auc 的 mean(均值)>= 0.8 的那些行。我想可以把 rank_results() 函数和一些 join 操作结合起来得到这个结果,但也许有一种更“干净”的方法?

提前致谢! M.

library(titanic)
library(tidyverse)
library(tidymodels)
library(finetune)
library(themis)
#> Registered S3 methods overwritten by 'themis':
#>   method                  from   
#>   bake.step_downsample    recipes
#>   bake.step_upsample      recipes
#>   prep.step_downsample    recipes
#>   prep.step_upsample      recipes
#>   tidy.step_downsample    recipes
#>   tidy.step_upsample      recipes
#>   tunable.step_downsample recipes
#>   tunable.step_upsample   recipes
#> 
#> Attaching package: 'themis'
#> The following objects are masked from 'package:recipes':
#> 
#>     step_downsample, step_upsample
library(discrim)
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness

# Turn on the `tidymodels.dark` option; presumably this switches tidymodels
# console colours to a palette suited to dark-themed terminals — verify.
options(list(tidymodels.dark = TRUE))

## Preparing the training data and cross-validation folds

# Clean the titanic training data: encode the outcome and the categorical
# predictors as factors (passenger class as an ordered factor) and drop the
# Name, Ticket, Cabin and Embarked columns.
# Read explicitly from the package copy: the result is assigned to a global
# with the same name, so reading the bare `titanic_train` would make a second
# run of this block fail (the dropped columns would already be gone).
titanic_train <- titanic::titanic_train %>%
  as_tibble() %>%
  mutate(
    Survived = factor(Survived),
    Pclass = factor(Pclass, ordered = TRUE),
    Sex = factor(Sex)
  ) %>%
  select(!c(Name, Ticket, Cabin, Embarked))

# 5 x 5 repeated V-fold cross-validation on the training data.
# Seed the RNG first so the (random) fold assignment, and therefore every
# resampled metric computed later, is reproducible. Stratify on the class
# outcome so each fold keeps roughly the overall survived/died ratio.
set.seed(123)
titanic_folds <- vfold_cv(titanic_train, v = 5, repeats = 5, strata = Survived)

## Model Definition

# Random forest (ranger engine) with a fixed 200 trees; `mtry` and `min_n`
# are left as tune() placeholders so workflow_map() can search over them.
rf_model <-
  rand_forest(
    mode = "classification",
    trees = 200,
    mtry = tune(),
    min_n = tune()
  ) %>%
  set_engine("ranger")

# Boosted trees (xgboost engine) with a fixed 200 trees; six hyperparameters
# are left as tune() placeholders for the racing search further below.
xgb_model <-
  boost_tree(
    mode = "classification",
    trees = 200,
    mtry = tune(),
    min_n = tune(),
    tree_depth = tune(),
    learn_rate = tune(),
    loss_reduction = tune(),
    sample_size = tune()
  ) %>%
  set_engine("xgboost")

# Naive Bayes classifier (discrim) using the naivebayes engine. It has no
# tune() arguments, so workflow_map() only resamples it.
naive_bayes_model <-
  naive_Bayes(mode = "classification") %>%
  set_engine("naivebayes")

# Baseline preprocessing recipe:
# - keep PassengerId as an identifier rather than a predictor,
# - KNN-impute missing predictor values (5 neighbours),
# - dummy-encode the nominal predictors,
# - downsample the outcome classes (seeded for reproducibility).
base_rec <- titanic_train %>%
  recipe(formula = Survived ~ .) %>%
  update_role(PassengerId, new_role = "id") %>%
  step_impute_knn(all_predictors(), neighbors = 5) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_downsample(Survived, seed = 123)

# Variant of the baseline recipe: identical steps plus step_normalize() on
# the numeric predictors, applied before downsampling.
another_rec <- recipe(Survived ~ ., data = titanic_train) %>%
  # Same role label as the baseline recipe ("id"). Role names are matched
  # exactly, so mixing "ID" and "id" would make role-based selection
  # (e.g. has_role("id")) behave differently between the two recipes.
  update_role(PassengerId, new_role = "id") %>%
  step_impute_knn(all_predictors(), neighbors = 5) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_downsample(Survived, seed = 123)

# Cross every preprocessing recipe with every model specification
# (2 recipes x 3 models = 6 workflows). The list names are combined into
# the workflow ids, e.g. "base_rf" or "another_bayes".
titanic_models <- workflow_set(
  preproc = list(base = base_rec, another = another_rec),
  models = list(rf = rf_model, xgb = xgb_model, bayes = naive_bayes_model),
  cross = TRUE
)

# Show the not-yet-evaluated workflow set: one row per recipe x model pair,
# with empty `option` and `result` columns.
print(titanic_models)
#> # A workflow set/tibble: 6 x 4
#>   wflow_id      info                 option    result
#>   <chr>         <list>               <list>    <list>
#> 1 base_rf       <tibble[,4] [1 × 4]> <opts[0]> <list [0]>
#> 2 base_xgb      <tibble[,4] [1 × 4]> <opts[0]> <list [0]>
#> 3 base_bayes    <tibble[,4] [1 × 4]> <opts[0]> <list [0]>
#> 4 another_rf    <tibble[,4] [1 × 4]> <opts[0]> <list [0]>
#> 5 another_xgb   <tibble[,4] [1 × 4]> <opts[0]> <list [0]>
#> 6 another_bayes <tibble[,4] [1 × 4]> <opts[0]> <list [0]>

# Start a local (default PSOCK) cluster and register it as the foreach
# backend so that the tuning below runs in parallel.
# Leave two cores free for the rest of the system, but never request fewer
# than one worker: detectCores() may return NA, and on machines with two or
# fewer cores `detectCores() - 2` would be zero or negative, making
# makeCluster() fail.
num_cores <- max(1L, parallel::detectCores() - 2L, na.rm = TRUE)

cl <- parallel::makeCluster(num_cores)
doParallel::registerDoParallel(cl = cl)

# Evaluate every workflow in the set with tune_race_anova() (racing, which
# drops clearly inferior candidates early). It uses 4 candidate parameter
# combinations per tunable workflow and both accuracy and ROC AUC as metrics.
# Workflows without tuning parameters fall back to fit_resamples().
# `seed` fixes the RNG seed used for each workflow; the default picks a random
# one, which would make reruns non-reproducible. Predictions and workflows
# are kept for later inspection.
titanic_models_result <- titanic_models %>%
  workflow_map(
    "tune_race_anova",
    resamples = titanic_folds,
    grid = 4,
    metrics = metric_set(accuracy, roc_auc),
    verbose = TRUE,
    seed = 123,
    control = control_race(
      verbose = TRUE,
      save_pred = TRUE,
      save_workflow = TRUE
    )
  )
#> i 1 of 6 tuning:     base_rf
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔ 1 of 6 tuning:     base_rf (38.3s)
#> i 2 of 6 tuning:     base_xgb
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔ 2 of 6 tuning:     base_xgb (7s)
#> i    No tuning parameters. `fit_resamples()` will be attempted
#> i 3 of 6 resampling: base_bayes
#> ✔ 3 of 6 resampling: base_bayes (3.3s)
#> i 4 of 6 tuning:     another_rf
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔ 4 of 6 tuning:     another_rf (29.8s)
#> i 5 of 6 tuning:     another_xgb
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔ 5 of 6 tuning:     another_xgb (6s)
#> i    No tuning parameters. `fit_resamples()` will be attempted
#> i 6 of 6 resampling: another_bayes
#> ✔ 6 of 6 resampling: another_bayes (2.4s)

# Shut down the worker processes, then restore the sequential foreach
# backend. Otherwise foreach, and therefore tune, stays registered to the
# stopped cluster, and later parallel-capable calls would fail.
parallel::stopCluster(cl)
foreach::registerDoSEQ()

# Print the evaluated workflow set. `result` now holds race results for the
# tuned workflows and plain resampling results for the naive Bayes ones.
# (Console output is commented out with `#>` so the script still parses.)
titanic_models_result
#>   wflow_id      info                 option    result
#>   <chr>         <list>               <list>    <list>
#> 1 base_rf       <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
#> 2 base_xgb      <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
#> 3 base_bayes    <tibble[,4] [1 × 4]> <opts[4]> <rsmp[+]>
#> 4 another_rf    <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
#> 5 another_xgb   <tibble[,4] [1 × 4]> <opts[4]> <race[+]>
#> 6 another_bayes <tibble[,4] [1 × 4]> <opts[4]> <rsmp[+]>

# Rank every evaluated configuration across all workflows by mean ROC AUC.
# Each configuration contributes one row per metric (accuracy and roc_auc).
# (Console output is commented out with `#>` so the script still parses.)
rank_results(titanic_models_result, rank_metric = "roc_auc")
#> # A tibble: 20 x 9
#>    wflow_id  .config    .metric  mean std_err     n preprocessor model
#>    <chr>     <fct>      <chr>   <dbl>   <dbl> <int> <chr>        <chr>
#>  1 base_rf   Preproces… accura… 0.816 0.00579    25 recipe       rand_…
#>  2 base_rf   Preproces… roc_auc 0.870 0.00626    25 recipe       rand_…
#>  3 another_… Preproces… accura… 0.816 0.00569    25 recipe       rand_…
#>  4 another_… Preproces… roc_auc 0.870 0.00628    25 recipe       rand_…
#>  5 another_… Preproces… accura… 0.816 0.00611    25 recipe       rand_…
#>  6 another_… Preproces… roc_auc 0.869 0.00633    25 recipe       rand_…
#>  7 another_… Preproces… accura… 0.817 0.00461    25 recipe       rand_…
#>  8 another_… Preproces… roc_auc 0.869 0.00619    25 recipe       rand_…
#>  9 base_rf   Preproces… accura… 0.816 0.00610    25 recipe       rand_…
#> 10 base_rf   Preproces… roc_auc 0.869 0.00625    25 recipe       rand_…
#> 11 base_rf   Preproces… accura… 0.817 0.00470    25 recipe       rand_…
#> 12 base_rf   Preproces… roc_auc 0.869 0.00614    25 recipe       rand_…
#> 13 base_xgb  Preproces… accura… 0.766 0.00737    25 recipe       boost…
#> 14 base_xgb  Preproces… roc_auc 0.836 0.00646    25 recipe       boost…
#> 15 another_… Preproces… accura… 0.766 0.00737    25 recipe       boost…
#> 16 another_… Preproces… roc_auc 0.836 0.00646    25 recipe       boost…
#> 17 base_bay… Preproces… accura… 0.757 0.00692    25 recipe       naive…
#> 18 base_bay… Preproces… roc_auc 0.812 0.00640    25 recipe       naive…
#> 19 another_… Preproces… accura… 0.756 0.00704    25 recipe       naive…
#> 20 another_… Preproces… roc_auc 0.811 0.00636    25 recipe       naive…
#> # … with 1 more variable: rank <int>

由 reprex 包 (v2.0.0) 于 2021-10-13 创建

有几种方法可以解决这个问题,但使用 rank_results() 肯定是一个不错的方法。

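  • 您可以先用 collect_metrics() 收集指标,再用 filter() 找出符合条件的工作流,例如 `collect_metrics(titanic_models_result) %>% filter(.metric == "roc_auc", mean >= 0.8)`;然后用得到的 wflow_id 对工作流集本身做 filter()(工作流集本身也是一个 tibble)。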

  • 别忘了,您还可以用 extract_*() 系列函数提取工作流集中的各个组件。

查看 workflowsets 中目前已实现的功能(what's currently implemented in workflowsets),看看是否有其他函数适合您的具体用例。