如何逐步提取 mlr3 中调优后的图(tuned graph)的各个步骤?

How to extract mlr3 tuned graph step by step?

我的代码如下

library(mlr3verse)
library(mlr3pipelines)
library(mlr3filters)
library(paradox)
# Feature filter: score features with ranger impurity importance and keep
# the top 70 % of them (filter.frac = 0.7).
filter_importance <- po(
  "filter",
  filter = flt(
    "importance",
    learner = lrn("classif.ranger", importance = "impurity")
  ),
  filter.frac = 0.7
)

# Final classifier: random forest returning class probabilities.
learner_classif <- lrn(
  "classif.ranger",
  predict_type = "prob",
  importance = "impurity",
  num.trees = 500
)
polrn_classif <- po("learner", learner = learner_classif)

# Create the learner graph (filter -> ranger) and wrap it as a learner.
glrn_classif <- filter_importance %>>% polrn_classif
glrn_classif <- GraphLearner$new(glrn_classif)
glrn_classif$predict_type <- "prob"

# Task
task <- tsk("german_credit")

# Search space. Parameter names carry the "classif.ranger." prefix because
# the ranger learner lives inside the graph under that id.
# ps()/p_int()/p_dbl() replace ParamSet$new(list(ParamInt$new(...), ...)),
# which is no longer available in paradox >= 1.0.0.
ps_classif <- ps(
  classif.ranger.num.trees = p_int(lower = 300, upper = 500),
  classif.ranger.sample.fraction = p_dbl(lower = 0.7, upper = 0.8)
)

# Auto tuning: inner 3-fold CV, random search limited to 3 evaluations,
# optimising AUC.
at <- AutoTuner$new(
  learner = glrn_classif,
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.auc"),
  search_space = ps_classif,
  terminator = trm("evals", n_evals = 3),
  tuner = tnr("random_search")
)

# Outer resampling (2-fold CV) of the auto-tuned learner.
rr <- resample(task, at, rsmp("cv", folds = 2))

在通过重采样得到 rr 对象、并得到训练好的学习器 at 之后,请问如何提取其中各个步骤的结果?

例如:

非常感谢!!!

为了能够在重采样之后查看和摆弄(fiddle with)各个模型,最好在调用 resample 时设置 store_models = TRUE。

使用你的例子

library(mlr3verse)

# Fix the RNG seed before resampling so the run can be repeated.
set.seed(1)

# Outer 2-fold CV of the AutoTuner; store_models = TRUE keeps the fitted
# models in the result so they can be inspected afterwards.
rr <- resample(
  task = task,
  learner = at,
  resampling = rsmp("cv", folds = 2),
  store_models = TRUE
)

完成重采样后,您可以像这样访问生成对象的内部结构:

获取每个折叠中的行 ID:

# Fold assignment of the outer resampling: one row per row_id with its fold.
rr$resampling$instance
#output
      row_id fold
   1:      5    1
   2:      8    1
   3:      9    1
   4:     12    1
   5:     13    1
  ---            
 996:    989    2
 997:    993    2
 998:    994    2
 999:    995    2
1000:    996    2

有了这些行 ID 以及每一折中已经训练好的 AutoTuner,我们就可以手动生成预测。

生成测试索引列表

# Named list with one vector of test row ids per fold.
rsample <- with(
  rr$resampling$instance,
  split(row_id, fold)
)

遍历各折及其对应的已调优 AutoTuner,并进行预测:

# For every outer fold, predict on that fold's test rows using the AutoTuner
# stored for the same resampling iteration. Iterating with seq_along() keeps
# this correct for any number of folds (and for an empty list).
preds_manual <- lapply(seq_along(rsample), function(i) {
  test_ids <- rsample[[i]] # test row ids of fold i
  task_test <- task$clone() # clone so the original task is not changed
  task_test$filter(test_ids) # keep only the test rows
  rr$learners[[i]]$predict(task_test) # trained autotuner of iteration i
})

检查这些预测是否与重采样的输出一致:

# Compare the manually generated predictions with the ones stored in rr.
all.equal(
  target = preds_manual,
  current = rr$predictions()
)
#output
TRUE

获取有关调整的信息

# Trained AutoTuners, one per outer resampling iteration. Uses the public
# `$learners` accessor instead of reaching into the internal `$data` object.
zz <- rr$learners

# Tuning result stored on each AutoTuner.
lapply(zz, function(x) x$tuning_result)
#output
[[1]]
   classif.ranger.num.trees classif.ranger.sample.fraction learner_param_vals
1:                      342                      0.7931022          <list[7]>
    x_domain classif.auc
1: <list[2]>   0.7981283

[[2]]
   classif.ranger.num.trees classif.ranger.sample.fraction learner_param_vals
1:                      407                      0.7964164          <list[7]>
    x_domain classif.auc
1: <list[2]>   0.7706533

下面这个槽(slot):

# State stored under `importance` in the first AutoTuner's tuned learner.
zz[[1]]$learner$state$model$importance

包含有关 filter_importance 步骤的信息

特别是

# Per-fold `scores` of the `importance` step (one named vector per fold).
lapply(zz, function(tuned) {
  filter_state <- tuned$learner$state$model$importance
  filter_state$scores
})
#output
[[1]]
                 amount                  status                     age 
              27.491369               25.776145               22.021369 
               duration                 purpose          credit_history 
              18.732521               16.251643               14.884843 
    employment_duration                 savings                property 
              11.225678               10.796583                9.078619 
    personal_status_sex       present_residence        installment_rate 
               8.914802                7.875384                7.491573 
                    job          number_credits other_installment_plans 
               6.293323                5.662485                5.345666 
                housing               telephone           other_debtors 
               4.869471                3.742213                3.548856 
          people_liable          foreign_worker 
               2.632163                1.054919 

[[2]]
                 amount                duration                     age 
              26.764389               22.139400               20.749865 
                 status                 purpose     employment_duration 
              20.524764               11.793789               10.962301 
         credit_history        installment_rate                 savings 
              10.416572                9.597835                9.491894 
               property       present_residence                     job 
               9.403157                7.877391                6.760945 
    personal_status_sex                 housing other_installment_plans 
               6.699065                5.811131                5.710761 
              telephone           other_debtors          number_credits 
               4.716322                4.318972                3.974793 
          people_liable          foreign_worker 
               3.196563                0.846520 

包含特征的排名。而

# Per-fold `outtasklayout` of the `importance` step (feature id and type).
lapply(zz, function(tuned) {
  filter_state <- tuned$learner$state$model$importance
  filter_state$outtasklayout
})
#output
[[1]]
                     id    type
 1:                 age integer
 2:              amount integer
 3:      credit_history  factor
 4:            duration integer
 5: employment_duration  factor
 6:    installment_rate ordered
 7:                 job  factor
 8:      number_credits ordered
 9: personal_status_sex  factor
10:   present_residence ordered
11:            property  factor
12:             purpose  factor
13:             savings  factor
14:              status  factor

[[2]]
                     id    type
 1:                 age integer
 2:              amount integer
 3:      credit_history  factor
 4:            duration integer
 5: employment_duration  factor
 6:             housing  factor
 7:    installment_rate ordered
 8:                 job  factor
 9: personal_status_sex  factor
10:   present_residence ordered
11:            property  factor
12:             purpose  factor
13:             savings  factor
14:              status  factor

包含过滤步骤后保留的特征。