从 fit_resamples() 结果中提取模型 object 的最快方法
Fastest way to extract a model object from fit_resamples() results
这个问题是针对 tidymodels 用户的,如果你比较懒,直接跳过整个文本直接跳到下面的粗体问题
我正在寻找从拟合重采样 (tune::fit_resample()
) 中提取防风草模型 object 的最有效方法。
当我想用 cross-validation 训练模型时,我可以选择 tune::tune_grid()
或 fit_resamples()
。
假设我知道我的算法的最佳参数,所以我不需要任何参数调整,这意味着我决定使用 fit_resamples()
。
如果我决定使用 tune_grid()
,我通常会设置一个工作流程,因为我在 tune_grid 运行 之后评估不同的模型:我选择 tune::show_best()
和 tune::select_best()
为我的模型探索和提取最佳参数。然后我去 tune::finalize_workflow()
,workflows::pull_wokrflow_fit()
提取我的模型 object。此外,当我想查看预测时,我会选择 tune::last_fit()
和 tune::collect_predictions()
当我使用 fit_resamples()
时,所有这些步骤似乎都是多余的,因为我基本上只有一个具有稳定参数的模型。所以上面的所有这些步骤都不是必需的,但我必须通过它们。我呢?
执行 fit_resamples()
后,我得到一个小标题,其中包含有关 .splits、.metrics、.notes 等的信息。
所以我的问题真的归结为:
- 从
fit_resamples()
的输出小标题到我最终的防风草模型 object 的最快方法是什么?
了解 fit_resamples()
的重要一点是它的目的是 衡量绩效 。您在 fit_resamples()
中训练的模型不会保留或以后使用。
假设您知道要用于 SVM 模型的参数。
library(tidymodels)
#> ── Attaching packages ─────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom 0.7.0 ✓ recipes 0.1.13
#> ✓ dials 0.0.8 ✓ rsample 0.0.7
#> ✓ dplyr 1.0.0 ✓ tibble 3.0.3
#> ✓ ggplot2 3.3.2 ✓ tidyr 1.1.0
#> ✓ infer 0.5.3 ✓ tune 0.1.1
#> ✓ modeldata 0.0.2 ✓ workflows 0.1.2
#> ✓ parsnip 0.1.2 ✓ yardstick 0.0.7
#> ✓ purrr 0.3.4
#> ── Conflicts ────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
## pretend this is your training data
data("hpc_data")
svm_spec <- svm_poly(degree = 1, cost = 1/4) %>%
set_engine("kernlab") %>%
set_mode("regression")
svm_wf <- workflow() %>%
add_model(svm_spec) %>%
add_formula(compounds ~ .)
hpc_folds <- vfold_cv(hpc_data)
svm_rs <- svm_wf %>%
fit_resamples(
resamples = hpc_folds
)
svm_rs
#> # Resampling results
#> # 10-fold cross-validation
#> # A tibble: 10 x 4
#> splits id .metrics .notes
#> <list> <chr> <list> <list>
#> 1 <split [3.9K/434]> Fold01 <tibble [2 × 3]> <tibble [0 × 1]>
#> 2 <split [3.9K/433]> Fold02 <tibble [2 × 3]> <tibble [0 × 1]>
#> 3 <split [3.9K/433]> Fold03 <tibble [2 × 3]> <tibble [0 × 1]>
#> 4 <split [3.9K/433]> Fold04 <tibble [2 × 3]> <tibble [0 × 1]>
#> 5 <split [3.9K/433]> Fold05 <tibble [2 × 3]> <tibble [0 × 1]>
#> 6 <split [3.9K/433]> Fold06 <tibble [2 × 3]> <tibble [0 × 1]>
#> 7 <split [3.9K/433]> Fold07 <tibble [2 × 3]> <tibble [0 × 1]>
#> 8 <split [3.9K/433]> Fold08 <tibble [2 × 3]> <tibble [0 × 1]>
#> 9 <split [3.9K/433]> Fold09 <tibble [2 × 3]> <tibble [0 × 1]>
#> 10 <split [3.9K/433]> Fold10 <tibble [2 × 3]> <tibble [0 × 1]>
此输出中没有拟合模型。每个重采样都安装了模型,但您不想将它们用于任何事情;它们被丢弃是因为它们的唯一目的是计算 .metrics
以估计性能。
如果您希望模型用于预测新数据,您需要返回到整个训练集并使用整个训练集再次拟合模型。
svm_fit <- svm_wf %>%
fit(hpc_data)
svm_fit
#> ══ Workflow [trained] ═══════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: svm_poly()
#>
#> ── Preprocessor ─────────────────────────────────────────────────────
#> compounds ~ .
#>
#> ── Model ────────────────────────────────────────────────────────────
#> Support Vector Machine object of class "ksvm"
#>
#> SV type: eps-svr (regression)
#> parameter : epsilon = 0.1 cost C = 0.25
#>
#> Polynomial kernel function.
#> Hyperparameters : degree = 1 scale = 1 offset = 1
#>
#> Number of Support Vectors : 2827
#>
#> Objective Function Value : -284.7255
#> Training error : 0.835421
由 reprex package (v0.3.0)
于 2020-07-17 创建
最后一个对象可以与 pull_workflow_fit()
一起用于可变重要性或类似对象。
这个问题是针对 tidymodels 用户的,如果你比较懒,直接跳过整个文本直接跳到下面的粗体问题
我正在寻找从拟合重采样 (tune::fit_resample()
) 中提取防风草模型 object 的最有效方法。
当我想用 cross-validation 训练模型时,我可以选择 tune::tune_grid()
或 fit_resamples()
。
假设我知道我的算法的最佳参数,所以我不需要任何参数调整,这意味着我决定使用 fit_resamples()
。
如果我决定使用 tune_grid()
,我通常会设置一个工作流程,因为我在 tune_grid 运行 之后评估不同的模型:我选择 tune::show_best()
和 tune::select_best()
为我的模型探索和提取最佳参数。然后我去 tune::finalize_workflow()
,workflows::pull_wokrflow_fit()
提取我的模型 object。此外,当我想查看预测时,我会选择 tune::last_fit()
和 tune::collect_predictions()
当我使用 fit_resamples()
时,所有这些步骤似乎都是多余的,因为我基本上只有一个具有稳定参数的模型。所以上面的所有这些步骤都不是必需的,但我必须通过它们。我呢?
执行 fit_resamples()
后,我得到一个小标题,其中包含有关 .splits、.metrics、.notes 等的信息。
所以我的问题真的归结为:
- 从
fit_resamples()
的输出小标题到我最终的防风草模型 object 的最快方法是什么?
了解 fit_resamples()
的重要一点是它的目的是 衡量绩效 。您在 fit_resamples()
中训练的模型不会保留或以后使用。
假设您知道要用于 SVM 模型的参数。
library(tidymodels)
#> ── Attaching packages ─────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom 0.7.0 ✓ recipes 0.1.13
#> ✓ dials 0.0.8 ✓ rsample 0.0.7
#> ✓ dplyr 1.0.0 ✓ tibble 3.0.3
#> ✓ ggplot2 3.3.2 ✓ tidyr 1.1.0
#> ✓ infer 0.5.3 ✓ tune 0.1.1
#> ✓ modeldata 0.0.2 ✓ workflows 0.1.2
#> ✓ parsnip 0.1.2 ✓ yardstick 0.0.7
#> ✓ purrr 0.3.4
#> ── Conflicts ────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
## pretend this is your training data
data("hpc_data")
svm_spec <- svm_poly(degree = 1, cost = 1/4) %>%
set_engine("kernlab") %>%
set_mode("regression")
svm_wf <- workflow() %>%
add_model(svm_spec) %>%
add_formula(compounds ~ .)
hpc_folds <- vfold_cv(hpc_data)
svm_rs <- svm_wf %>%
fit_resamples(
resamples = hpc_folds
)
svm_rs
#> # Resampling results
#> # 10-fold cross-validation
#> # A tibble: 10 x 4
#> splits id .metrics .notes
#> <list> <chr> <list> <list>
#> 1 <split [3.9K/434]> Fold01 <tibble [2 × 3]> <tibble [0 × 1]>
#> 2 <split [3.9K/433]> Fold02 <tibble [2 × 3]> <tibble [0 × 1]>
#> 3 <split [3.9K/433]> Fold03 <tibble [2 × 3]> <tibble [0 × 1]>
#> 4 <split [3.9K/433]> Fold04 <tibble [2 × 3]> <tibble [0 × 1]>
#> 5 <split [3.9K/433]> Fold05 <tibble [2 × 3]> <tibble [0 × 1]>
#> 6 <split [3.9K/433]> Fold06 <tibble [2 × 3]> <tibble [0 × 1]>
#> 7 <split [3.9K/433]> Fold07 <tibble [2 × 3]> <tibble [0 × 1]>
#> 8 <split [3.9K/433]> Fold08 <tibble [2 × 3]> <tibble [0 × 1]>
#> 9 <split [3.9K/433]> Fold09 <tibble [2 × 3]> <tibble [0 × 1]>
#> 10 <split [3.9K/433]> Fold10 <tibble [2 × 3]> <tibble [0 × 1]>
此输出中没有拟合模型。每个重采样都安装了模型,但您不想将它们用于任何事情;它们被丢弃是因为它们的唯一目的是计算 .metrics
以估计性能。
如果您希望模型用于预测新数据,您需要返回到整个训练集并使用整个训练集再次拟合模型。
svm_fit <- svm_wf %>%
fit(hpc_data)
svm_fit
#> ══ Workflow [trained] ═══════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: svm_poly()
#>
#> ── Preprocessor ─────────────────────────────────────────────────────
#> compounds ~ .
#>
#> ── Model ────────────────────────────────────────────────────────────
#> Support Vector Machine object of class "ksvm"
#>
#> SV type: eps-svr (regression)
#> parameter : epsilon = 0.1 cost C = 0.25
#>
#> Polynomial kernel function.
#> Hyperparameters : degree = 1 scale = 1 offset = 1
#>
#> Number of Support Vectors : 2827
#>
#> Objective Function Value : -284.7255
#> Training error : 0.835421
由 reprex package (v0.3.0)
于 2020-07-17 创建最后一个对象可以与 pull_workflow_fit()
一起用于可变重要性或类似对象。