从 Tidymodels 中经过训练的 C5.0 模型中提取规则

Extract Rules from Trained C5.0 Model in Tidymodels

我本可以(也应该)做一个更简单的 reprex,但下面确实是我手头正在处理的内容。在 Tidymodels 框架中训练好一个基于规则的模型(如 C5.0)之后,如何“查看”模型生成的规则?(注:下面的示例代码实际拟合的是 Cubist 规则模型 `cubist_rules()`。)我试图复现下面这篇文章中说明的内容

https://www.tidyverse.org/blog/2020/05/rules-0-0-1/

但我没能取得什么进展(不过我确信解决方案应该只需要一行代码)。

非常感谢!

library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 0.1.2 ──
#> ✔ broom     0.7.2          ✔ recipes   0.1.15    
#> ✔ dials     0.0.9          ✔ rsample   0.0.8     
#> ✔ dplyr     1.0.2          ✔ tibble    3.0.4     
#> ✔ ggplot2   3.3.2          ✔ tidyr     1.1.2     
#> ✔ infer     0.5.3          ✔ tune      0.1.2.9000
#> ✔ modeldata 0.1.0          ✔ workflows 0.2.1     
#> ✔ parsnip   0.1.4.9000     ✔ yardstick 0.0.7     
#> ✔ purrr     0.3.4
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter()  masks stats::filter()
#> ✖ dplyr::lag()     masks stats::lag()
#> ✖ recipes::step()  masks stats::step()
library(rules)
#> 
#> Attaching package: 'rules'
#> The following object is masked from 'package:dials':
#> 
#>     max_rules


df_ini <- structure(list(year = c(2002, 2004, 2005, 2006, 2007, 2008, 2009, 
2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019), 
    berd = c(3130.884, 3556.479, 4207.669, 4448.676, 4845.861, 
    5232.63, 5092.902, 5520.422, 5692.841, 6540.457, 6778.42, 
    7324.679, 7498.488, 7824.51, 7888.444, 8461.72, 8865.96), 
    gbaord = c(1537.89, 1537.89, 1619.74, 1697.55, 1770.144, 
    1986.775, 2149.787, 2269.986, 2428.143, 2452.955, 2587.586, 
    2647.489, 2744.844, 2875.706, 2889.779, 2913.369, 3081.087
    ), employment_be = c(2775.22, 2731.08, 2709.59, 2708.39, 
    2774.06, 2795.6, 2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 
    2725.66, 2735.69, 2750.52, 2782.9, 2852.33, 2890.9), employment_c = c(2562.53, 
    2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 
    2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48, 
    2622.5, 2656.89), employment_j = c(400.75, 387.53, 384.64, 
    389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 419.75, 427.59, 
    438.96, 440.33, 460.84, 473.4, 494.4, 513.62), employment_k = c(502.42, 
    504.63, 515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 
    534.98, 524.13, 518.89, 511.57, 505.32, 496.41, 495.4, 495.98
    ), employment_mn = c(1248.01, 1365.29, 1425.81, 1537.88, 
    1622.95, 1727.76, 1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 
    1950.02, 1968.83, 2021.51, 2109.71, 2189.27, 2225), employment_oq = c(3241.36, 
    3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23, 
    3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72, 
    4238.87, 4284.27), employment_total = c(15113.52, 15307.28, 
    15491.61, 15762.92, 16050.92, 16356.53, 16269.97, 16392.87, 
    16647.79, 16820.66, 16879.06, 17039.6, 17142.13, 17365.32, 
    17650.21, 17951.61, 18156.52), value_be = c(47967.1, 50737.6, 
    52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 61443, 
    63655.1, 64132.3, 65542.6, 67495.4, 71152.6, 72698.8, 75792.3, 
    77284), value_c = c(40192.9, 42014.6, 44229, 47735.5, 51552.4, 
    51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169, 
    57458.7, 60962.8, 62196, 65063.5, 66063.6), value_j = c(7737.1, 
    7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 
    9405.1, 9802.1, 10361.4, 10695.4, 11455.3, 11720.6, 12871, 
    13540.3), value_k = c(10225.2, 10541.9, 11005.3, 11912.3, 
    13102.7, 13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 
    12962.4, 13482.9, 13236.4, 13744.1, 14152.6, 14739.1), value_mn = c(15074, 
    16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 
    24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6, 
    33781.9, 35152.9), value_oq = c(35065.6, 37329.6, 38288.8, 
    40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5, 
    50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9, 61680.1
    ), value_total = c(202353.5, 216098.3, 225888.1, 239076, 
    253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 
    297230.1, 307037.7, 318952.7, 329396.1, 344338.6, 355359.1
    ), gdp_b1gq = c(226735.3, 242348.3, 254075, 267824.4, 283978, 
    293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 
    333146.1, 344269.3, 357608, 369341.3, 385361.9, 397575.4), 
    gdp_p3 = c(164107.8, 176316.4, 185871.1, 194102, 200944.4, 
    208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 
    243860.6, 249404.3, 257166.5, 265900.2, 274583.7, 282863.3
    ), gdp_p61 = c(74691.6, 83074.9, 90010.4, 100076.8, 110157.2, 
    113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 
    126109.3, 129183.6, 131524, 140057.8, 150278.2, 152545.2), 
    gdp_p62 = c(28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8, 
    38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 
    55885.5, 59584.7, 64333.5, 68409.7), turnover_manu_dom = c(80, 
    87, 93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7, 
    107.1, 104.7, 102.9, 107.9, 107.9, 107.9), turnover_manu_non_dom = c(70.9, 
    81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7, 
    112.8, 114.9, 118.2, 120.1, 129.2, 129.2, 129.2), turnover_manu_tot = c(74.7, 
    84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9, 111.9, 
    111.7, 112.6, 112.9, 120.3, 120.3, 120.3), price_index = c(1.7, 
    2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 
    1, 2.2, 2.1, 1.5), capital_n1132g = c(3638.4, 3633.3, 3616.2, 
    3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 
    3467.9, 4214.2, 4237.4, 4450.2, 4598.6, 4721), capital_n117g = c(8369.6, 
    8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 
    13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1, 18273.8, 
    19642.4, 20713.1), capital_n11mg = c(18749.6, 19433.5, 20051.6, 
    20569.8, 22646.1, 23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 
    24057.7, 23832.8, 25019.2, 27608.2, 29790.1, 30998, 32856.8
    ), lagged_gbaord = c(1537.89, 1537.89, 1537.89, 1619.74, 
    1697.55, 1770.144, 1986.775, 2149.787, 2269.986, 2428.143, 
    2452.955, 2587.586, 2647.489, 2744.844, 2875.706, 2889.779, 
    2913.369), lagged_employment_be = c(2775.22, 2775.22, 2731.08, 
    2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 2705.1, 
    2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9, 2852.33
    ), lagged_employment_c = c(2562.53, 2562.53, 2518.57, 2496.98, 
    2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 2507.41, 
    2500.94, 2499.6, 2511.75, 2523.97, 2555.48, 2622.5), lagged_employment_j = c(400.75, 
    400.75, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 
    410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4, 494.4
    ), lagged_employment_k = c(502.42, 502.42, 504.63, 515.39, 
    523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 524.13, 
    518.89, 511.57, 505.32, 496.41, 495.4), lagged_employment_mn = c(1248.01, 
    1248.01, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 
    1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51, 
    2109.71, 2189.27), lagged_employment_oq = c(3241.36, 3241.36, 
    3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23, 
    3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72, 
    4238.87), lagged_employment_total = c(15113.52, 15113.52, 
    15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 16269.97, 
    16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 17142.13, 
    17365.32, 17650.21, 17951.61), lagged_value_be = c(47967.1, 
    47967.1, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 
    58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6, 
    72698.8, 75792.3), lagged_value_c = c(40192.9, 40192.9, 42014.6, 
    44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 51467.7, 
    53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196, 65063.5
    ), lagged_value_j = c(7737.1, 7737.1, 7756.1, 8134.2, 8378.8, 
    8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 10361.4, 
    10695.4, 11455.3, 11720.6, 12871), lagged_value_k = c(10225.2, 
    10225.2, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 12123.9, 
    12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 13236.4, 
    13744.1, 14152.6), lagged_value_mn = c(15074, 15074, 16569.1, 
    18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 24895.4, 
    25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6, 33781.9
    ), lagged_value_oq = c(35065.6, 35065.6, 37329.6, 38288.8, 
    40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5, 
    50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9), lagged_value_total = c(202353.5, 
    202353.5, 216098.3, 225888.1, 239076, 253604.6, 262414.7, 
    256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 307037.7, 
    318952.7, 329396.1, 344338.6), lagged_gdp_b1gq = c(226735.3, 
    226735.3, 242348.3, 254075, 267824.4, 283978, 293761.9, 288044.1, 
    295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 344269.3, 
    357608, 369341.3, 385361.9), lagged_gdp_p3 = c(164107.8, 
    164107.8, 176316.4, 185871.1, 194102, 200944.4, 208857.1, 
    213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6, 
    249404.3, 257166.5, 265900.2, 274583.7), lagged_gdp_p61 = c(74691.6, 
    74691.6, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1, 
    91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3, 
    129183.6, 131524, 140057.8, 150278.2), lagged_gdp_p62 = c(28063.4, 
    28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9, 
    39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5, 
    59584.7, 64333.5), lagged_turnover_manu_dom = c(80, 80, 87, 
    93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7, 107.1, 
    104.7, 102.9, 107.9, 107.9), lagged_turnover_manu_non_dom = c(70.9, 
    70.9, 81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7, 
    112.8, 114.9, 118.2, 120.1, 129.2, 129.2), lagged_turnover_manu_tot = c(74.7, 
    74.7, 84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9, 
    111.9, 111.7, 112.6, 112.9, 120.3, 120.3), lagged_price_index = c(1.7, 
    1.7, 2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 
    0.8, 1, 2.2, 2.1), lagged_capital_n1132g = c(3638.4, 3638.4, 
    3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 
    4005.6, 3718.6, 3467.9, 4214.2, 4237.4, 4450.2, 4598.6), 
    lagged_capital_n117g = c(8369.6, 8369.6, 8679.9, 8938.9, 
    9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 13465.4, 13927.5, 
    15510.2, 15754.4, 16584.7, 17647.1, 18273.8, 19642.4), lagged_capital_n11mg = c(18749.6, 
    18749.6, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6, 
    20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2, 
    29790.1, 30998), country = c("AT", "AT", "AT", "AT", "AT", 
    "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", 
    "AT", "AT")), row.names = c(NA, -17L), class = c("tbl_df", 
"tbl", "data.frame"))




# Hold out the last `time_back` rows (the most recent years) for assessment
# and train on everything that comes before them.
set.seed(1234)

nn <- nrow(df_ini)

time_back <- 1

# A holdout of zero rows, or of every row, would leave one side of the split
# empty; with `1:(nn - time_back)` the index would silently count backwards.
stopifnot(time_back >= 1, time_back < nn)

# seq_len() and tail() keep the integer type of the row positions and never
# produce a descending sequence.
indices <-
  list(analysis   = seq_len(nn - time_back),
       assessment = utils::tail(seq_len(nn), time_back)
       )

df_split <- make_splits(indices, df_ini)

df_train <- training(df_split)
df_test <- testing(df_split)

# 3-fold cross-validation on the training rows, used below for tuning.
folded_data <- vfold_cv(df_train, v = 3)

# Preprocessing: model `berd` on every other column, keep `year` only as an
# identifier, and drop zero-variance predictors.
# (A step_string2factor() on `country` was tried here and left disabled.)
cubist_recipe <- recipe(formula = berd ~ ., data = df_train)
cubist_recipe <- update_role(cubist_recipe, year, new_role = "ID")
cubist_recipe <- step_zv(cubist_recipe, all_predictors())

# Cubist rule-based regression with both main arguments left for tuning.
cubist_spec <- set_engine(
  cubist_rules(committees = tune(), neighbors = tune()),
  "Cubist"
)

# Bundle preprocessing and model into a single workflow.
cubist_workflow <- add_model(
  add_recipe(workflow(), cubist_recipe),
  cubist_spec
)

# Candidate values: committees 1-9 plus 10, 20, ..., 50; neighbors 0/3/6/9.
cubist_grid <- tidyr::crossing(
  committees = c(1:9, (1:5) * 10),
  neighbors = c(0, 3, 6, 9)
)

# Evaluate every grid combination on the cross-validation folds.
cubist_tune <- tune_grid(
  cubist_workflow,
  resamples = folded_data,
  grid = cubist_grid
)
#> 
#> Attaching package: 'rlang'
#> The following objects are masked from 'package:purrr':
#> 
#>     %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
#>     flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
#>     splice
#> 
#> Attaching package: 'vctrs'
#> The following object is masked from 'package:tibble':
#> 
#>     data_frame
#> The following object is masked from 'package:dplyr':
#> 
#>     data_frame
#> Loading required package: lattice


# Pick the tuning combination with the best RMSE. `metric` is passed by name:
# newer tune releases no longer accept it positionally.
best_cub <- select_best(cubist_tune, metric = "rmse")

# Plug the selected parameter values into the workflow's tune() placeholders.
final_cub <- finalize_workflow(
  cubist_workflow,
  best_cub
)

# Print the finalized (still unfitted) workflow.
final_cub
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: cubist_rules()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#> 
#> ● step_zv()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Cubist Model Specification (regression)
#> 
#> Main Arguments:
#>   committees = 1
#>   neighbors = 3
#> 
#> Computational engine: Cubist
   
# Fit the finalized workflow on the training rows.
fit_model <- fit(final_cub, data = df_train)

# Print the trained workflow (preprocessor plus the Cubist model summary line).
fit_model
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: cubist_rules()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#> 
#> ● step_zv()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> 
#> Call:
#> cubist.default(x = x, y = y, committees = 1)
#> 
#> Number of samples: 16 
#> Number of predictors: 52 
#> 
#> Number of committees: 1 
#> Number of rules: 1

 ### at this point how to see the rules in the model trained on the data ???

reprex package (v0.3.0)

于 2020 年 12 月 10 日创建

诚然,tidymodels 目前提供的解决方案还不太理想。我认为目前查看模型规则的最好办法,是先提取深藏在工作流内部的底层拟合对象,再对它调用 summary()。也就是说,你需要执行的是:`summary(fit_model$fit$fit$fit)`。

library(tidymodels)
library(rules)
#> 
#> Attaching package: 'rules'
#> The following object is masked from 'package:dials':
#> 
#>     max_rules

df_ini <- structure(list(year = c(2002, 2004, 2005, 2006, 2007, 2008, 2009, 
                                  2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019), 
                         berd = c(3130.884, 3556.479, 4207.669, 4448.676, 4845.861, 
                                  5232.63, 5092.902, 5520.422, 5692.841, 6540.457, 6778.42, 
                                  7324.679, 7498.488, 7824.51, 7888.444, 8461.72, 8865.96), 
                         gbaord = c(1537.89, 1537.89, 1619.74, 1697.55, 1770.144, 
                                    1986.775, 2149.787, 2269.986, 2428.143, 2452.955, 2587.586, 
                                    2647.489, 2744.844, 2875.706, 2889.779, 2913.369, 3081.087
                         ), employment_be = c(2775.22, 2731.08, 2709.59, 2708.39, 
                                              2774.06, 2795.6, 2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 
                                              2725.66, 2735.69, 2750.52, 2782.9, 2852.33, 2890.9), 
                         employment_c = c(2562.53, 
                                          2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 
                                          2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48, 
                                          2622.5, 2656.89), 
                         employment_j = c(400.75, 387.53, 384.64, 
                                          389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 419.75, 427.59, 
                                          438.96, 440.33, 460.84, 473.4, 494.4, 513.62), 
                         employment_k = c(502.42, 
                                          504.63, 515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 
                                          534.98, 524.13, 518.89, 511.57, 505.32, 496.41, 495.4, 495.98
                         ), 
                         employment_mn = c(1248.01, 1365.29, 1425.81, 1537.88, 
                                           1622.95, 1727.76, 1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 
                                           1950.02, 1968.83, 2021.51, 2109.71, 2189.27, 2225), 
                         employment_oq = c(3241.36, 
                                           3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23, 
                                           3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72, 
                                           4238.87, 4284.27), 
                         employment_total = c(15113.52, 15307.28, 
                                              15491.61, 15762.92, 16050.92, 16356.53, 16269.97, 16392.87, 
                                              16647.79, 16820.66, 16879.06, 17039.6, 17142.13, 17365.32, 
                                              17650.21, 17951.61, 18156.52), 
                         value_be = c(47967.1, 50737.6, 
                                      52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 61443, 
                                      63655.1, 64132.3, 65542.6, 67495.4, 71152.6, 72698.8, 75792.3, 
                                      77284), 
                         value_c = c(40192.9, 42014.6, 44229, 47735.5, 51552.4, 
                                     51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169, 
                                     57458.7, 60962.8, 62196, 65063.5, 66063.6), 
                         value_j = c(7737.1, 
                                     7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 
                                     9405.1, 9802.1, 10361.4, 10695.4, 11455.3, 11720.6, 12871, 
                                     13540.3), 
                         value_k = c(10225.2, 10541.9, 11005.3, 11912.3, 
                                     13102.7, 13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 
                                     12962.4, 13482.9, 13236.4, 13744.1, 14152.6, 14739.1), 
                         value_mn = c(15074, 
                                      16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 
                                      24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6, 
                                      33781.9, 35152.9), 
                         value_oq = c(35065.6, 37329.6, 38288.8, 
                                      40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5, 
                                      50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9, 61680.1
                         ), 
                         value_total = c(202353.5, 216098.3, 225888.1, 239076, 
                                         253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 
                                         297230.1, 307037.7, 318952.7, 329396.1, 344338.6, 355359.1
                         ), 
                         gdp_b1gq = c(226735.3, 242348.3, 254075, 267824.4, 283978, 
                                      293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 
                                      333146.1, 344269.3, 357608, 369341.3, 385361.9, 397575.4), 
                         gdp_p3 = c(164107.8, 176316.4, 185871.1, 194102, 200944.4, 
                                    208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 
                                    243860.6, 249404.3, 257166.5, 265900.2, 274583.7, 282863.3
                         ), gdp_p61 = c(74691.6, 83074.9, 90010.4, 100076.8, 110157.2, 
                                        113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 
                                        126109.3, 129183.6, 131524, 140057.8, 150278.2, 152545.2), 
                         gdp_p62 = c(28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8, 
                                     38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 
                                     55885.5, 59584.7, 64333.5, 68409.7), 
                         turnover_manu_dom = c(80, 
                                               87, 93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7, 
                                               107.1, 104.7, 102.9, 107.9, 107.9, 107.9), 
                         turnover_manu_non_dom = c(70.9, 
                                                   81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7, 
                                                   112.8, 114.9, 118.2, 120.1, 129.2, 129.2, 129.2), 
                         turnover_manu_tot = c(74.7, 
                                               84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9, 111.9, 
                                               111.7, 112.6, 112.9, 120.3, 120.3, 120.3), 
                         price_index = c(1.7, 
                                         2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 
                                         1, 2.2, 2.1, 1.5), 
                         capital_n1132g = c(3638.4, 3633.3, 3616.2, 
                                            3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 
                                            3467.9, 4214.2, 4237.4, 4450.2, 4598.6, 4721), 
                         capital_n117g = c(8369.6, 
                                           8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 
                                           13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1, 18273.8, 
                                           19642.4, 20713.1), capital_n11mg = c(18749.6, 19433.5, 20051.6, 
                                                                                20569.8, 22646.1, 23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 
                                                                                24057.7, 23832.8, 25019.2, 27608.2, 29790.1, 30998, 32856.8
                                           ), 
                         lagged_gbaord = c(1537.89, 1537.89, 1537.89, 1619.74, 
                                           1697.55, 1770.144, 1986.775, 2149.787, 2269.986, 2428.143, 
                                           2452.955, 2587.586, 2647.489, 2744.844, 2875.706, 2889.779, 
                                           2913.369), 
                         lagged_employment_be = c(2775.22, 2775.22, 2731.08, 
                                                  2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 2705.1, 
                                                  2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9, 2852.33
                         ), lagged_employment_c = c(2562.53, 2562.53, 2518.57, 2496.98, 
                                                    2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 2507.41, 
                                                    2500.94, 2499.6, 2511.75, 2523.97, 2555.48, 2622.5), 
                         lagged_employment_j = c(400.75, 
                                                 400.75, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 
                                                 410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4, 494.4
                         ), 
                         lagged_employment_k = c(502.42, 502.42, 504.63, 515.39, 
                                                 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 524.13, 
                                                 518.89, 511.57, 505.32, 496.41, 495.4), 
                         lagged_employment_mn = c(1248.01, 
                                                  1248.01, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 
                                                  1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51, 
                                                  2109.71, 2189.27), 
                         lagged_employment_oq = c(3241.36, 3241.36, 
                                                  3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23, 
                                                  3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72, 
                                                  4238.87), 
                         lagged_employment_total = c(15113.52, 15113.52, 
                                                     15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 16269.97, 
                                                     16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 17142.13, 
                                                     17365.32, 17650.21, 17951.61), 
                         lagged_value_be = c(47967.1, 
                                             47967.1, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 
                                             58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6, 
                                             72698.8, 75792.3), 
                         lagged_value_c = c(40192.9, 40192.9, 42014.6, 
                                            44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 51467.7, 
                                            53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196, 65063.5
                         ), 
                         lagged_value_j = c(7737.1, 7737.1, 7756.1, 8134.2, 8378.8, 
                                            8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 10361.4, 
                                            10695.4, 11455.3, 11720.6, 12871), 
                         lagged_value_k = c(10225.2, 
                                            10225.2, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 12123.9, 
                                            12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 13236.4, 
                                            13744.1, 14152.6), 
                         lagged_value_mn = c(15074, 15074, 16569.1, 
                                             18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 24895.4, 
                                             25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6, 33781.9
                         ), 
                         lagged_value_oq = c(35065.6, 35065.6, 37329.6, 38288.8, 
                                             40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5, 
                                             50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9), 
                         lagged_value_total = c(202353.5, 
                                                202353.5, 216098.3, 225888.1, 239076, 253604.6, 262414.7, 
                                                256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 307037.7, 
                                                318952.7, 329396.1, 344338.6), 
                         lagged_gdp_b1gq = c(226735.3, 
                                             226735.3, 242348.3, 254075, 267824.4, 283978, 293761.9, 288044.1, 
                                             295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 344269.3, 
                                             357608, 369341.3, 385361.9), 
                         lagged_gdp_p3 = c(164107.8, 
                                           164107.8, 176316.4, 185871.1, 194102, 200944.4, 208857.1, 
                                           213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6, 
                                           249404.3, 257166.5, 265900.2, 274583.7), 
                         lagged_gdp_p61 = c(74691.6, 
                                            74691.6, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1, 
                                            91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3, 
                                            129183.6, 131524, 140057.8, 150278.2), 
                         lagged_gdp_p62 = c(28063.4, 
                                            28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9, 
                                            39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5, 
                                            59584.7, 64333.5), 
                         lagged_turnover_manu_dom = c(80, 80, 87, 
                                                      93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7, 107.1, 
                                                      104.7, 102.9, 107.9, 107.9), 
                         lagged_turnover_manu_non_dom = c(70.9, 
                                                          70.9, 81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7, 
                                                          112.8, 114.9, 118.2, 120.1, 129.2, 129.2), 
                         lagged_turnover_manu_tot = c(74.7, 
                                                      74.7, 84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9, 
                                                      111.9, 111.7, 112.6, 112.9, 120.3, 120.3), 
                         lagged_price_index = c(1.7, 
                                                1.7, 2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 
                                                0.8, 1, 2.2, 2.1), lagged_capital_n1132g = c(3638.4, 3638.4, 
                                                                                             3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 
                                                                                             4005.6, 3718.6, 3467.9, 4214.2, 4237.4, 4450.2, 4598.6), 
                         lagged_capital_n117g = c(8369.6, 8369.6, 8679.9, 8938.9, 
                                                  9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 13465.4, 13927.5, 
                                                  15510.2, 15754.4, 16584.7, 17647.1, 18273.8, 19642.4), 
                         lagged_capital_n11mg = c(18749.6, 
                                                  18749.6, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6, 
                                                  20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2, 
                                                  29790.1, 30998), country = c("AT", "AT", "AT", "AT", "AT", 
                                                                               "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", 
                                                                               "AT", "AT")), 
                    row.names = c(NA, -17L), class = c("tbl_df", 
                                                       "tbl", "data.frame"))




# Hold out the last `time_back` rows (the most recent years) for assessment
# and train on everything that comes before them.
set.seed(1234)

nn <- nrow(df_ini)

time_back <- 1

# A holdout of zero rows, or of every row, would leave one side of the split
# empty; with `1:(nn - time_back)` the index would silently count backwards.
stopifnot(time_back >= 1, time_back < nn)

# seq_len() and tail() keep the integer type of the row positions and never
# produce a descending sequence.
indices <-
  list(analysis   = seq_len(nn - time_back),
       assessment = utils::tail(seq_len(nn), time_back)
  )

df_split <- make_splits(indices, df_ini)

df_train <- training(df_split)
df_test <- testing(df_split)

# 3-fold cross-validation on the training rows, used below for tuning.
folded_data <- vfold_cv(df_train, v = 3)

# Preprocessing: model `berd` on every other column, keep `year` only as an
# identifier, and drop zero-variance predictors.
# (A step_string2factor() on `country` was tried here and left disabled.)
cubist_recipe <- recipe(formula = berd ~ ., data = df_train)
cubist_recipe <- update_role(cubist_recipe, year, new_role = "ID")
cubist_recipe <- step_zv(cubist_recipe, all_predictors())

# Cubist rule-based regression with both main arguments left for tuning.
cubist_spec <- set_engine(
  cubist_rules(committees = tune(), neighbors = tune()),
  "Cubist"
)

# Bundle preprocessing and model into a single workflow.
cubist_workflow <- add_model(
  add_recipe(workflow(), cubist_recipe),
  cubist_spec
)

# Candidate values: committees 1-9 plus 10, 20, ..., 50; neighbors 0/3/6/9.
cubist_grid <- tidyr::crossing(
  committees = c(1:9, (1:5) * 10),
  neighbors = c(0, 3, 6, 9)
)

# Evaluate every grid combination on the cross-validation folds.
cubist_tune <- tune_grid(
  cubist_workflow,
  resamples = folded_data,
  grid = cubist_grid
)
#> 
#> Attaching package: 'rlang'
#> The following objects are masked from 'package:purrr':
#> 
#>     %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
#>     flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
#>     splice
#> 
#> Attaching package: 'vctrs'
#> The following object is masked from 'package:tibble':
#> 
#>     data_frame
#> The following object is masked from 'package:dplyr':
#> 
#>     data_frame
#> Loading required package: lattice

# Pick the tuning combination with the best RMSE. `metric` is passed by name:
# newer tune releases no longer accept it positionally.
best_cub <- select_best(cubist_tune, metric = "rmse")

# Plug the selected parameter values into the workflow's tune() placeholders.
final_cub <- finalize_workflow(
  cubist_workflow,
  best_cub
)

# Fit the finalized workflow on the training rows.
fit_model <- final_cub %>%
  fit(df_train)

# Drill down workflow -> parsnip fit -> engine object and let the engine's own
# summary() method print the rules.
# NOTE(review): newer workflows releases appear to offer extract_fit_engine()
# for this lookup -- confirm against the installed version before switching.
summary(fit_model$fit$fit$fit)
#> 
#> Call:
#> cubist.default(x = x, y = y, committees = 1)
#> 
#> 
#> Cubist [Release 2.07 GPL Edition]  Thu Dec 10 16:52:59 2020
#> ---------------------------------
#> 
#>     Target attribute `outcome'
#> 
#> Read 16 cases (53 attributes) from undefined.data
#> 
#> Model:
#> 
#>   Rule 1: [16 cases, mean 5877.817, range 3130.884 to 8461.72, est err 251.023]
#> 
#>  outcome = -5043.087 + 0.0357 gdp_b1gq
#> 
#> 
#> Evaluation on training data (16 cases):
#> 
#>     Average  |error|            196.045
#>     Relative |error|               0.14
#>     Correlation coefficient        0.99
#> 
#> 
#>  Attribute usage:
#>    Conds  Model
#> 
#>           100%    gdp_b1gq
#> 
#> 
#> Time: 0.0 secs

reprex package (v0.3.0.9001)

于 2020-12-10 创建

如果你想取出系数以便进一步处理,可以查看 `as_tibble(fit_model$fit$fit$fit$coefficients)` 返回的结果。