Tidymodels 包:使用 ggplot() 可视化袋装树以显示最重要的预测变量

Tidymodels Package: Visualising Bagged Trees using ggplot() to show the most important predictors

概览:

我正在按照教程(见下文)从袋装树、随机森林、增强树和一般线性模型中找到最合适的模型。

教程(参见下面的示例)

https://bcullen.rbind.io/post/2020-06-02-tidymodels-decision-tree-learning-in-r/

问题

在这种情况下,我想进一步探索数据并可视化我的数据最重要的预测变量(见下图)。

我的数据框称为 FID,袋装树模型中的预测变量涉及:

  1. 年份 Year(数值型)
  2. 月份 Month(因子)
  3. 天数 Days(数值型)

因变量是频率(数值)

当我尝试运行绘图代码来可视化最重要的预测变量时,一直收到以下错误消息:

错误信息

Error: Can't subset columns that don't exist.
x Column `.extracts` doesn't exist.
Run `rlang::last_error()` to see where the error occurred.
Called from: rlang:::signal_abort(x)

如果有人对如何修复错误消息有任何建议,我将不胜感激。

非常感谢

示例:教程中用于生成该图的 R 代码

可视化模型

绘制以显示最重要的预测变量

我的 R 代码

###########################################################
# Question's pipeline as originally posted: split FID, build a recipe and
# resample a bagged rpart regression tree. Kept verbatim so it reproduces
# the error reported below.
#
# NOTE(review): there is no set.seed() call, so the split and the folds
# change from run to run.
#split this single dataset into two: a training set and a testing set
data_split <- initial_split(FID)
# Create data frames for the two sets:
train_data <- training(data_split)
test_data  <- testing(data_split)

 # resample the data with 10-fold cross-validation (10-fold by default)
  cv <- vfold_cv(train_data)
###########################################################

##Produce the recipe
# NOTE(review): the formula names `Frequency_Blue`, but the FID data shown
# further down has a `Frequency` column and no `Frequency_Blue`; the answer
# uses `Frequency ~ .` instead.

rec <- recipe(Frequency_Blue ~ ., data = FID) %>% 
          step_nzv(all_predictors(), freq_cut = 0, unique_cut = 0) %>% # remove variables with zero variances
          step_novel(all_nominal()) %>% # prepares test data to handle previously unseen factor levels 
          step_medianimpute(all_numeric(), -all_outcomes(), -has_role("id vars"))  %>% # replaces missing numeric observations with the median
          step_dummy(all_nominal(), -has_role("id vars")) # dummy codes categorical variables

###################################################################################

#####Fit the Bagged Tree Model
mod_bag <- bag_tree() %>%
            set_mode("regression") %>%
             set_engine("rpart", times = 10) #10 bootstrap resamples
                

##Create workflow
wflow_bag <- workflow() %>% 
                   add_recipe(rec) %>%
                       add_model(mod_bag)

##Fit the model
# plan() is not defined in this snippet; presumably future::plan(), which
# lets the resamples be fitted in parallel background R sessions.
plan(multisession)

fit_bag <- fit_resamples(
                      wflow_bag,
                      cv,
                      metrics = metric_set(rmse, rsq),
                      # control_resamples() is given no `extract =` function,
                      # so the result has no `.extracts` column; that is what
                      # bag_roots() below trips over.
                      control = control_resamples(save_pred = TRUE)
                      )

##########################################################
##Visualise the model

##Open a plotting window
dev.new()

# extract roots
# Question's version of bag_roots(): meant to return the root-split variable
# of every bagged tree. As posted it stops at the first select(), because
# fit_resamples() above was run without `extract =` and so `x` has no
# `.extracts` column ("Column `.extracts` doesn't exist").
#
# NOTE(review): even with `.extracts` present, `~.x$FID` names the data frame
# rather than a field of the extracted fit, and `unnest(cols = c(fit_bag))` /
# `map_chr(fit_bag, ...)` refer to the results object instead of a column
# created by the previous step. The answer uses `.x$model_df`, `models` and
# `model` at these points.
bag_roots <-  function(x){
                      x %>% 
                      dplyr::select(.extracts) %>% 
                      unnest(cols = c(.extracts)) %>% 
                      dplyr::mutate(models = map(.extracts,
                      ~.x$FID)) %>% 
                      dplyr::select(-.extracts) %>% 
                      unnest(cols = c(fit_bag)) %>% 
                      mutate(root = map_chr(fit_bag,
                      ~as.character(.x$fit$frame[1, 1]))) %>%
                      dplyr::select(root)  
              }


# plot the bagged tree model
# Bar chart of how often each variable is the root split. fct_infreq() orders
# the levels by frequency and fct_rev() reverses them, so after coord_flip()
# the most frequent root is at the top. In the question this line is never
# reached, because bag_roots() errors first.
  bag_roots(fit_bag) %>% 
          ggplot(mapping = aes(x = fct_rev(fct_infreq(root)))) + 
          geom_bar() + 
          coord_flip() + 
          labs(x = "root", y = "count")

 #Error Message

  Error: Can't subset columns that don't exist.
  x Column `.extracts` doesn't exist.
  Run `rlang::last_error()` to see where the error occurred.
  Called from: rlang:::signal_abort(x)
    

数据框 - FID

  structure(list(Year = c(2015, 2015, 2015, 2015, 2015, 2015, 2015, 
2015, 2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016, 2016, 2016, 
2016, 2016, 2016, 2016, 2016, 2016, 2017, 2017, 2017, 2017, 2017, 
2017, 2017, 2017, 2017, 2017, 2017, 2017), Month = structure(c(1L, 
2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 1L, 2L, 3L, 4L, 
5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 
8L, 9L, 10L, 11L, 12L), .Label = c("January", "February", "March", 
"April", "May", "June", "July", "August", "September", "October", 
"November", "December"), class = "factor"), Frequency = c(36, 
28, 39, 46, 5, 0, 0, 22, 10, 15, 8, 33, 33, 29, 31, 23, 8, 9, 
7, 40, 41, 41, 30, 30, 44, 37, 41, 42, 20, 0, 7, 27, 35, 27, 
43, 38), Days = c(31, 28, 31, 30, 6, 0, 0, 29, 15, 
29, 29, 31, 31, 29, 30, 30, 7, 0, 7, 30, 30, 31, 30, 27, 31, 
28, 30, 30, 21, 0, 7, 26, 29, 27, 29, 29)), row.names = c(NA, 
-36L), class = "data.frame")

这里有几处需要调整:

  • 确保在 fit_resamples() 期间通过 extract 提取您需要的对象(即在 control_resamples() 中设置 extract 参数);
  • 在您自己编写的 bag_roots() 函数中,使用与您的数据相符的正确变量名称。

结果会是这样:

# Reprex: fit a bagged regression tree on FID with resampling, keeping every
# resample's fitted bagged model so its root splits can be inspected later.
library(tidymodels)
library(baguette)

FID <- structure(
  list(
    Year = c(2015, 2015, 2015, 2015, 2015, 2015, 2015, 2015, 2015, 2015,
             2015, 2015, 2016, 2016, 2016, 2016, 2016, 2016, 2016, 2016,
             2016, 2016, 2016, 2016, 2017, 2017, 2017, 2017, 2017, 2017,
             2017, 2017, 2017, 2017, 2017, 2017),
    Month = structure(
      c(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L,
        1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L,
        1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L),
      .Label = c("January", "February", "March", "April", "May", "June",
                 "July", "August", "September", "October", "November",
                 "December"),
      class = "factor"
    ),
    Frequency = c(36, 28, 39, 46, 5, 0, 0, 22, 10, 15, 8, 33,
                  33, 29, 31, 23, 8, 9, 7, 40, 41, 41, 30, 30,
                  44, 37, 41, 42, 20, 0, 7, 27, 35, 27, 43, 38),
    Days = c(31, 28, 31, 30, 6, 0, 0, 29, 15, 29, 29, 31,
             31, 29, 30, 30, 7, 0, 7, 30, 30, 31, 30, 27,
             31, 28, 30, 30, 21, 0, 7, 26, 29, 27, 29, 29)
  ),
  row.names = c(NA, -36L),
  class = "data.frame"
)

# Fix the RNG seed so that the train/test split and the CV folds (and
# therefore every result below) are reproducible from run to run.
set.seed(123)
data_split <- initial_split(FID)
train_data <- training(data_split)
test_data  <- testing(data_split)
cv <- vfold_cv(train_data, v = 3)

# The recipe template comes from the training data only; the workflow
# re-estimates each step on the analysis set of every resample.
rec <- recipe(Frequency ~ ., data = train_data) %>%
  # remove predictors with zero variance (a single unique value)
  step_zv(all_predictors()) %>%
  # let factor predictors absorb levels that were unseen during training
  step_novel(all_nominal(), -all_outcomes()) %>%
  # replace missing numeric predictor values with the training median
  # (newer recipes releases rename this step step_impute_median())
  step_medianimpute(all_numeric(), -all_outcomes()) %>%
  # dummy-code factor predictors, never the outcome
  step_dummy(all_nominal(), -all_outcomes())

mod_bag <- bag_tree() %>%
  set_mode("regression") %>%
  set_engine("rpart", times = 10) # 10 bootstrap resamples

wflow_bag <- workflow() %>%
  add_recipe(rec) %>%
  add_model(mod_bag)

fit_bag <- fit_resamples(
  wflow_bag,
  cv,
  metrics = metric_set(rmse, rsq),
  control = control_resamples(
    save_pred = TRUE,
    # Keep the underlying bagged fit from every resample. This is what
    # creates the `.extracts` column that bag_roots() reads; without it you
    # get "Column `.extracts` doesn't exist". (Newer tune releases replace
    # extract_model() with extract_fit_engine().)
    extract = extract_model
  )
)
#> 
#> Attaching package: 'rlang'
#> The following objects are masked from 'package:purrr':
#> 
#>     %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
#>     flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
#>     splice
#> 
#> Attaching package: 'vctrs'
#> The following object is masked from 'package:tibble':
#> 
#>     data_frame
#> The following object is masked from 'package:dplyr':
#> 
#>     data_frame
#> 
#> Attaching package: 'rpart'
#> The following object is masked from 'package:dials':
#> 
#>     prune

# Return the root-split variable of every tree in a resampled bagged model.
#
# x:         result of tune::fit_resamples() for a workflow whose model is
#            baguette::bag_tree() with the "rpart" engine, fitted with
#            `control_resamples(extract = function(x) extract_model(x))`.
# drop_leaf: if TRUE, omit trees whose root node is a leaf (a tree with no
#            split, reported by rpart as "<leaf>"). Defaults to FALSE, which
#            keeps every tree.
#
# Returns a tibble with a single character column `root`, one row per tree
# across all resamples.
bag_roots <- function(x, drop_leaf = FALSE) {
  # fit_resamples() only creates `.extracts` when `extract =` is set; fail
  # with an actionable message instead of "Can't subset columns that don't
  # exist".
  if (!".extracts" %in% names(x)) {
    stop(
      "`x` has no `.extracts` column. Refit with ",
      "`control_resamples(extract = function(x) extract_model(x))`.",
      call. = FALSE
    )
  }

  roots <- x %>%
    dplyr::select(.extracts) %>%
    tidyr::unnest(cols = c(.extracts)) %>%
    # each extracted baguette fit keeps its member trees in `model_df`
    dplyr::mutate(models = purrr::map(.extracts, function(e) e$model_df)) %>%
    dplyr::select(-.extracts) %>%
    tidyr::unnest(cols = c(models)) %>%
    # rpart's `frame` has one row per node; row 1 is the root, and its `var`
    # column names the variable used for the split ("<leaf>" if none)
    dplyr::mutate(
      root = purrr::map_chr(
        model,
        function(m) as.character(m$fit$frame$var[1])
      )
    ) %>%
    dplyr::select(root)

  if (drop_leaf) {
    roots <- dplyr::filter(roots, root != "<leaf>")
  }
  roots
}


# Plot how often each variable is chosen as the root split across all bagged
# trees. Levels are ordered by frequency and then reversed, so after
# coord_flip() the most common root ends up at the top.
library(forcats)
bag_roots(fit_bag) %>%
  dplyr::mutate(root = fct_rev(fct_infreq(root))) %>%
  ggplot(aes(x = root)) +
  geom_bar() +
  coord_flip() +
  labs(x = "root", y = "count")

于 2020-11-20 由 reprex package (v0.3.0.9001) 创建

这个结果不是特别令人兴奋,但希望您真实的、更大的数据集能显示出更有趣的结果!