Error while predicting a GAM model using tidymodels
What I want: I'm trying to fit a GAM model for classification on the given data using tidymodels.
So far: I have been able to fit a logit model.
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
df_split <- initial_split(df, prop = 0.75, strata = class)
df_train <- training(df_split)
df_test <- testing(df_split)
log_model <- logistic_reg(mode = "classification",
                          engine = "glm") %>%
  fit(class ~ duration, data = df_train)
predict(log_model, df_test)
#> # A tibble: 26 × 1
#> .pred_class
#> <fct>
#> 1 good
#> 2 good
#> 3 good
#> 4 bad
#> 5 good
#> 6 good
#> 7 bad
#> 8 bad
#> 9 good
#> 10 bad
#> # … with 16 more rows
My problem: Surprisingly, I get an error when I try a GAM instead.
gen_model <- gen_additive_mod(mode = "classification",
                              engine = "mgcv") %>%
  fit(class ~ duration, data = df_train)
predict(gen_model, df_test)
#> Error: $ operator is invalid for atomic vectors
Data: Here is the dput of the df data frame:
df <- structure(list(class = structure(c(2L, 1L, 2L, 2L, 2L, 2L, 2L,
2L, 1L, 2L, 2L, 1L, 1L, 2L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, 1L, 2L,
1L, 2L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 2L, 2L, 1L,
1L, 2L, 1L, 2L, 2L, 2L, 2L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L,
2L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 1L, 2L, 2L,
1L, 2L, 1L, 2L, 2L, 2L, 1L, 1L, 1L, 2L, 1L, 2L, 1L, 1L, 1L, 1L,
2L, 1L, 1L, 2L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 1L),
.Label = c("bad",
"good"), class = "factor"),
duration = c(42, 31.7869911119342,
18, 24, 12, 18, 10, 9, 12, 24, 10, 27, 14.4910072591156, 12,
48, 24, 30, 18, 6, 6, 12, 48, 10, 18, 6, 12, 24.4157173759304,
18, 48, 60, 18, 15, 9, 60, 24, 24, 9, 21, 26.4959116294049, 12,
5, 12, 12, 48, 18, 48, 12, 17.4877766738646, 36, 9, 15, 39.2811119947582,
27, 21, 24, 10, 6, 12, 12, 24, 39, 18, 24, 15, 48, 12, 24, 26.7659258879721,
36, 24, 27, 9, 12, 48, 28, 21, 6, 24, 24, 24, 18, 36, 36, 30,
8.19771710922942, 36, 18, 12, 13.8241796996444, 26.0928970947862,
10, 36, 12, 12, 24, 21.3157193372026, 18, 21, 24, 24)),
class = c("tbl_df",
"tbl", "data.frame"),
row.names = c(NA, -100L))
Created on 2022-01-12 by the reprex package (v2.0.1)
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.1.2 (2021-11-01)
#> os macOS Big Sur 10.16
#> system x86_64, darwin17.0
#> ui X11
#> language (EN)
#> collate es_ES.UTF-8
#> ctype es_ES.UTF-8
#> tz Europe/Madrid
#> date 2022-01-12
#> pandoc 2.14.0.3 @ /Applications/RStudio.app/Contents/MacOS/pandoc/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.1.0)
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.1.0)
#> broom * 0.7.11 2022-01-03 [1] CRAN (R 4.1.2)
#> class 7.3-19 2021-05-03 [1] CRAN (R 4.1.2)
#> cli 3.1.0 2021-10-27 [1] CRAN (R 4.1.0)
#> codetools 0.2-18 2020-11-04 [1] CRAN (R 4.1.2)
#> colorspace 2.0-2 2021-06-24 [1] CRAN (R 4.1.0)
#> crayon 1.4.2 2021-10-29 [1] CRAN (R 4.1.0)
#> DBI 1.1.2 2021-12-20 [1] CRAN (R 4.1.2)
#> dials * 0.0.10 2021-09-10 [1] CRAN (R 4.1.0)
#> DiceDesign 1.9 2021-02-13 [1] CRAN (R 4.1.0)
#> digest 0.6.29 2021-12-01 [1] CRAN (R 4.1.0)
#> dplyr * 1.0.7 2021-06-18 [1] CRAN (R 4.1.0)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.1.0)
#> evaluate 0.14 2019-05-28 [1] CRAN (R 4.1.0)
#> fansi 1.0.0 2022-01-10 [1] CRAN (R 4.1.2)
#> fastmap 1.1.0 2021-01-25 [1] CRAN (R 4.1.0)
#> foreach 1.5.1 2020-10-15 [1] CRAN (R 4.1.0)
#> fs 1.5.2 2021-12-08 [1] CRAN (R 4.1.0)
#> furrr 0.2.3 2021-06-25 [1] CRAN (R 4.1.0)
#> future 1.23.0 2021-10-31 [1] CRAN (R 4.1.0)
#> future.apply 1.8.1 2021-08-10 [1] CRAN (R 4.1.0)
#> generics 0.1.1 2021-10-25 [1] CRAN (R 4.1.0)
#> ggplot2 * 3.3.5 2021-06-25 [1] CRAN (R 4.1.0)
#> globals 0.14.0 2020-11-22 [1] CRAN (R 4.1.0)
#> glue 1.6.0 2021-12-17 [1] CRAN (R 4.1.0)
#> gower 0.2.2 2020-06-23 [1] CRAN (R 4.1.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.1.0)
#> gtable 0.3.0 2019-03-25 [1] CRAN (R 4.1.0)
#> hardhat 0.1.6 2021-07-14 [1] CRAN (R 4.1.0)
#> highr 0.9 2021-04-16 [1] CRAN (R 4.1.0)
#> htmltools 0.5.2 2021-08-25 [1] CRAN (R 4.1.0)
#> infer * 1.0.0 2021-08-13 [1] CRAN (R 4.1.0)
#> ipred 0.9-12 2021-09-15 [1] CRAN (R 4.1.0)
#> iterators 1.0.13 2020-10-15 [1] CRAN (R 4.1.0)
#> knitr 1.37 2021-12-16 [1] CRAN (R 4.1.0)
#> lattice 0.20-45 2021-09-22 [1] CRAN (R 4.1.2)
#> lava 1.6.10 2021-09-02 [1] CRAN (R 4.1.0)
#> lhs 1.1.3 2021-09-08 [1] CRAN (R 4.1.0)
#> lifecycle 1.0.1 2021-09-24 [1] CRAN (R 4.1.0)
#> listenv 0.8.0 2019-12-05 [1] CRAN (R 4.1.0)
#> lubridate 1.8.0 2021-10-07 [1] CRAN (R 4.1.0)
#> magrittr 2.0.1 2020-11-17 [1] CRAN (R 4.1.0)
#> MASS 7.3-54 2021-05-03 [1] CRAN (R 4.1.2)
#> Matrix 1.4-0 2021-12-08 [1] CRAN (R 4.1.0)
#> mgcv 1.8-38 2021-10-06 [1] CRAN (R 4.1.2)
#> modeldata * 0.1.1 2021-07-14 [1] CRAN (R 4.1.0)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.1.0)
#> nlme 3.1-153 2021-09-07 [1] CRAN (R 4.1.2)
#> nnet 7.3-16 2021-05-03 [1] CRAN (R 4.1.2)
#> parallelly 1.30.0 2021-12-17 [1] CRAN (R 4.1.0)
#> parsnip * 0.1.7 2021-07-21 [1] CRAN (R 4.1.0)
#> pillar 1.6.4 2021-10-18 [1] CRAN (R 4.1.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.1.0)
#> plyr 1.8.6 2020-03-03 [1] CRAN (R 4.1.0)
#> pROC 1.18.0 2021-09-03 [1] CRAN (R 4.1.0)
#> prodlim 2019.11.13 2019-11-17 [1] CRAN (R 4.1.0)
#> purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.1.0)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.1.0)
#> Rcpp 1.0.7 2021-07-07 [1] CRAN (R 4.1.0)
#> recipes * 0.1.17 2021-09-27 [1] CRAN (R 4.1.0)
#> reprex 2.0.1 2021-08-05 [1] CRAN (R 4.1.0)
#> rlang 0.4.12 2021-10-18 [1] CRAN (R 4.1.0)
#> rmarkdown 2.11 2021-09-14 [1] CRAN (R 4.1.0)
#> rpart 4.1-15 2019-04-12 [1] CRAN (R 4.1.2)
#> rsample * 0.1.1 2021-11-08 [1] CRAN (R 4.1.0)
#> rstudioapi 0.13 2020-11-12 [1] CRAN (R 4.1.0)
#> scales * 1.1.1 2020-05-11 [1] CRAN (R 4.1.0)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.1.0)
#> stringi 1.7.6 2021-11-29 [1] CRAN (R 4.1.0)
#> stringr 1.4.0 2019-02-10 [1] CRAN (R 4.1.0)
#> survival 3.2-13 2021-08-24 [1] CRAN (R 4.1.2)
#> tibble * 3.1.6 2021-11-07 [1] CRAN (R 4.1.0)
#> tidymodels * 0.1.4.9000 2022-01-12 [1] Github (tidymodels/tidymodels@8486957)
#> tidyr * 1.1.4 2021-09-27 [1] CRAN (R 4.1.0)
#> tidyselect 1.1.1 2021-04-30 [1] CRAN (R 4.1.0)
#> timeDate 3043.102 2018-02-21 [1] CRAN (R 4.1.0)
#> tune * 0.1.6 2021-07-21 [1] CRAN (R 4.1.0)
#> utf8 1.2.2 2021-07-24 [1] CRAN (R 4.1.0)
#> vctrs 0.3.8 2021-04-29 [1] CRAN (R 4.1.0)
#> withr 2.4.3 2021-11-30 [1] CRAN (R 4.1.0)
#> workflows * 0.2.4 2021-10-12 [1] CRAN (R 4.1.0)
#> workflowsets * 0.1.0 2021-07-22 [1] CRAN (R 4.1.0)
#> xfun 0.29 2021-12-14 [1] CRAN (R 4.1.0)
#> yaml 2.2.1 2020-02-01 [1] CRAN (R 4.1.0)
#> yardstick * 0.0.9 2021-11-22 [1] CRAN (R 4.1.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
This issue has been resolved in the development version of {parsnip} (>0.1.7). You can install it by running remotes::install_github("tidymodels/parsnip").
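To confirm that the installed parsnip is newer than 0.1.7 before retrying, the version can be compared directly. This is only a minimal sketch: `parsnip_has_gam_fix` is a hypothetical helper name (it is not part of parsnip), and it assumes that any version above 0.1.7 carries the fix.

```r
# Hypothetical helper (not part of parsnip): TRUE when the given version is
# newer than 0.1.7, the release that produced the error in the question.
# The default argument looks up the installed parsnip and is only evaluated
# when no version is passed explicitly.
parsnip_has_gam_fix <- function(installed = utils::packageVersion("parsnip")) {
  as.package_version(installed) > as.package_version("0.1.7")
}

parsnip_has_gam_fix("0.1.7")       # FALSE: the CRAN release from the question
parsnip_has_gam_fix("0.1.7.9000")  # TRUE: a development-style version number
```

Calling `parsnip_has_gam_fix()` with no argument checks whatever version is currently installed.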
library(parsnip)
library(rsample)
df <- structure(list(class = structure(c(2L, 1L, 2L, 2L, 2L, 2L, 2L,
2L, 1L, 2L, 2L, 1L, 1L, 2L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, 1L, 2L,
1L, 2L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 2L, 2L, 1L,
1L, 2L, 1L, 2L, 2L, 2L, 2L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L,
2L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 1L, 2L, 2L,
1L, 2L, 1L, 2L, 2L, 2L, 1L, 1L, 1L, 2L, 1L, 2L, 1L, 1L, 1L, 1L,
2L, 1L, 1L, 2L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 1L),
.Label = c("bad",
"good"), class = "factor"),
duration = c(42, 31.7869911119342,
18, 24, 12, 18, 10, 9, 12, 24, 10, 27, 14.4910072591156, 12,
48, 24, 30, 18, 6, 6, 12, 48, 10, 18, 6, 12, 24.4157173759304,
18, 48, 60, 18, 15, 9, 60, 24, 24, 9, 21, 26.4959116294049, 12,
5, 12, 12, 48, 18, 48, 12, 17.4877766738646, 36, 9, 15, 39.2811119947582,
27, 21, 24, 10, 6, 12, 12, 24, 39, 18, 24, 15, 48, 12, 24, 26.7659258879721,
36, 24, 27, 9, 12, 48, 28, 21, 6, 24, 24, 24, 18, 36, 36, 30,
8.19771710922942, 36, 18, 12, 13.8241796996444, 26.0928970947862,
10, 36, 12, 12, 24, 21.3157193372026, 18, 21, 24, 24)),
class = c("tbl_df",
"tbl", "data.frame"),
row.names = c(NA, -100L))
df_split <- initial_split(df, prop = 0.75, strata = class)
df_train <- training(df_split)
df_test <- testing(df_split)
gen_model <- gen_additive_mod(mode = "classification",
                              engine = "mgcv") %>%
  fit(class ~ duration, data = df_train)
predict(gen_model, df_test)
#> # A tibble: 26 × 1
#> .pred_class
#> <fct>
#> 1 bad
#> 2 good
#> 3 good
#> 4 good
#> 5 bad
#> 6 good
#> 7 good
#> 8 good
#> 9 good
#> 10 good
#> # … with 16 more rows
Created on 2022-01-12 by the reprex package (v2.0.1)
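If upgrading parsnip is not possible, one possible workaround (not part of the answer above) is to call mgcv directly and turn the predicted probabilities into classes by hand. The sketch below uses simulated data as a stand-in for df, made up purely for illustration. The smooth term `s(duration)` and the 0.5 cutoff are assumptions too.

```r
library(mgcv)

set.seed(1)
# Simulated stand-in for the question's data (illustrative values only).
n <- 100
sim <- data.frame(duration = runif(n, min = 5, max = 60))
sim$class <- factor(
  ifelse(runif(n) < plogis(1.5 - 0.06 * sim$duration), "good", "bad"),
  levels = c("bad", "good")
)
train <- sim[1:75, ]
test <- sim[76:100, ]

# Recode the outcome to 0/1 so the fitted probability refers to "good".
train$is_good <- as.integer(train$class == "good")

# Binomial GAM with a smooth of duration (assumed; the question's formula
# used a plain linear term, class ~ duration).
gam_fit <- gam(is_good ~ s(duration), family = binomial, data = train)

# type = "response" returns the estimated P(class == "good") per test row.
p_good <- as.numeric(predict(gam_fit, newdata = test, type = "response"))

# Assumed 0.5 cutoff to turn probabilities into class labels.
pred <- data.frame(
  .pred_class = factor(ifelse(p_good > 0.5, "good", "bad"),
                       levels = c("bad", "good"))
)
head(pred)
```

The column name `.pred_class` only mirrors the tidymodels output shown above. The result here is a plain data frame, not a parsnip prediction.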