mlr3 predictions to new data with parameters from autotune

I have a follow-up question to an earlier question. As in the original question, I am using mlr3verse, have a new dataset, and would like to make predictions using the parameters that performed well during autotuning. The answer to that question was to use at$train(task). This seems to initiate tuning again. Does it take full advantage of the nested resampling by using those parameters?

Also, looking at $tuning_result, there are two sets of parameters, one called tune_x and one called params. What is the difference between these?

Thanks.

Edit: example workflow added below

library(mlr3verse)

set.seed(56624)
task = tsk("mtcars")

learner = lrn("regr.xgboost")

tune_ps = ParamSet$new(list(
  ParamDbl$new("eta", lower = .1, upper = .4),
  ParamInt$new("max_depth", lower = 2, upper = 4)
))

at = AutoTuner$new(learner = learner, 
                   resampling = rsmp("holdout"), # inner resampling
                   measures = msr("regr.mse"), 
                   tune_ps = tune_ps,
                   terminator = term("evals", n_evals = 3),
                   tuner = tnr("random_search"))

rr = resample(task = task, learner = at, resampling = rsmp("cv", folds = 2),
              store_models = TRUE)
rr$aggregate()
rr$score()
lapply(rr$learners, function(x) x$tuning_result)

at$train(task)
at$tuning_result

notreallynew.df = as.data.table(task)
at$predict_newdata(newdata = notreallynew.df)

As ?AutoTuner states, this class fits a model with the best hyperparameters found during tuning. That model is then used for predictions, in your case on newdata when you call its method $predict_newdata().
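One way to convince yourself that $predict_newdata() reuses the already-fitted model rather than tuning again is to inspect the trained AutoTuner. A minimal sketch, assuming `at` has been trained as in the workflow below (the slot names follow mlr3tuning's AutoTuner API):

```r
library(mlr3verse)

# after at$train(task), the wrapped learner carries the tuned
# hyperparameter values that the search selected ...
at$learner$param_set$values

# ... and an already-fitted model; $predict_newdata() only calls
# predict on this stored model, no tuning or training is triggered
at$learner$model
```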

You also see in ?AutoTuner that the documentation links to ?TuningInstance. That in turn tells you what the $tune_x and $params slots represent. Next time, try looking at the help pages - that is what they are there for ;)
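To make the difference between the two slots concrete, here is a sketch, assuming the same AutoTuner setup as above has been trained: $tune_x holds only the hyperparameters that were searched over, while $params is that set merged with any values already fixed on the learner beforehand.

```r
library(mlr3verse)

# assumes at$train(task) has been run as in the workflow above
res = at$tuning_result

# tune_x: only the parameters searched during tuning,
# i.e. eta and max_depth from the ParamSet
res$tune_x

# params: tune_x plus any parameter values that were already set on
# the learner before tuning (e.g. xgboost defaults set by mlr3learners),
# i.e. the complete configuration used to fit the final model
res$params
```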

This seems to initiate tuning again.

Why "again"? It does it for the first time, on all observations of task. I assume you might be confused by the common misunderstanding between "train/predict" and "resample". Read up on the theoretical differences between the two to understand what each is doing. They have completely different goals and are not connected. Maybe the reprex below makes it clearer.

library(mlr3verse)
#> Loading required package: mlr3
#> Loading required package: mlr3db
#> Loading required package: mlr3filters
#> Loading required package: mlr3learners
#> Loading required package: mlr3pipelines
#> Loading required package: mlr3tuning
#> Loading required package: mlr3viz
#> Loading required package: paradox

set.seed(56624)
lgr::get_logger("mlr3")$set_threshold("warn")
task = tsk("mtcars")
learner = lrn("regr.xgboost")

tune_ps = ParamSet$new(list(
  ParamDbl$new("eta", lower = .1, upper = .4),
  ParamInt$new("max_depth", lower = 2, upper = 4)
))

at = AutoTuner$new(
  learner = learner,
  resampling = rsmp("holdout"), # inner resampling
  measures = msr("regr.mse"),
  tune_ps = tune_ps,
  terminator = term("evals", n_evals = 3),
  tuner = tnr("random_search"))

# train/predict with AutoTuner
at$train(task)
notreallynew.df = as.data.table(task)
at$predict_newdata(newdata = notreallynew.df)
#> <PredictionRegr> for 32 observations:
#>     row_id truth response
#>          1  21.0 9.272631
#>          2  21.0 9.272631
#>          3  22.8 9.272631
#> ---                      
#>         30  19.7 9.272631
#>         31  15.0 5.875841
#>         32  21.4 9.272631

# resample with AutoTuner for performance estimation

rr = resample(
  task = task, learner = at, resampling = rsmp("cv", folds = 2),
  store_models = TRUE)
rr$aggregate()
#> regr.mse 
#> 240.5866
rr$score()
#>          task task_id     learner         learner_id     resampling
#> 1: <TaskRegr>  mtcars <AutoTuner> regr.xgboost.tuned <ResamplingCV>
#> 2: <TaskRegr>  mtcars <AutoTuner> regr.xgboost.tuned <ResamplingCV>
#>    resampling_id iteration prediction regr.mse
#> 1:            cv         1     <list> 220.8076
#> 2:            cv         2     <list> 260.3656

Created on 2020-05-07 by the reprex package (v0.3.0)