使用 tidymodels 预测 GAM 模型时出错

Error while predicting a GAM model using tidymodels

目标：我正在尝试使用 tidymodels 在给定数据上拟合一个用于分类的 GAM 模型。

目前进展：我已经能够拟合一个 logit（逻辑回归）模型，并用它在测试集上进行预测。

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

# `df` is created by the dput() output shown in the data section further down;
# run that first. initial_split() samples rows at random, so fix the seed to
# make the train/test split (and therefore the predictions) reproducible.
set.seed(123)
df_split <- initial_split(df, prop = 0.75, strata = class)
df_train <- training(df_split)
df_test <- testing(df_split)

# Baseline: logistic regression via the glm engine, which predicts fine.
log_model <- logistic_reg(mode = "classification",
                          engine = "glm") %>%
  fit(class ~ duration, data = df_train)

# Output below was captured before set.seed() was added; rows may differ.
predict(log_model, df_test)
#> # A tibble: 26 × 1
#>    .pred_class
#>    <fct>      
#>  1 good       
#>  2 good       
#>  3 good       
#>  4 bad        
#>  5 good       
#>  6 good       
#>  7 bad        
#>  8 bad        
#>  9 good       
#> 10 bad        
#> # … with 16 more rows

问题：令人意外的是，换成 GAM 模型后，拟合本身没有报错，但调用 predict() 时出现错误：

# Same outcome, predictor and training data as the logistic model, but using a
# GAM specification backed by the mgcv engine. The spec is passed straight to
# fit() instead of being piped into it.
# NOTE: the formula has no s() term, so mgcv fits a purely parametric model
# here; class ~ s(duration) would request a smooth effect of `duration`.
gen_model <- fit(
  gen_additive_mod(mode = "classification", engine = "mgcv"),
  class ~ duration,
  data = df_train
)

# With parsnip 0.1.7 this call fails:
predict(gen_model, df_test)
#> Error: $ operator is invalid for atomic vectors

数据：以下是数据框 df 的 dput() 输出：

# Example data, serialized with dput(): a tibble (classes tbl_df/tbl/data.frame)
# with 100 rows, a factor outcome `class` with levels "bad" and "good", and a
# numeric predictor `duration`.
df <- structure(list(class = structure(c(2L, 1L, 2L, 2L, 2L, 2L, 2L, 
                                         2L, 1L, 2L, 2L, 1L, 1L, 2L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, 1L, 2L, 
                                         1L, 2L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 2L, 2L, 1L, 
                                         1L, 2L, 1L, 2L, 2L, 2L, 2L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 
                                         2L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 
                                         1L, 2L, 1L, 2L, 2L, 2L, 1L, 1L, 1L, 2L, 1L, 2L, 1L, 1L, 1L, 1L, 
                                         2L, 1L, 1L, 2L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 1L), 
                                       .Label = c("bad", 
                                                  "good"), class = "factor"), 
                     duration = c(42, 31.7869911119342, 
                                  18, 24, 12, 18, 10, 9, 12, 24, 10, 27, 14.4910072591156, 12, 
                                  48, 24, 30, 18, 6, 6, 12, 48, 10, 18, 6, 12, 24.4157173759304, 
                                  18, 48, 60, 18, 15, 9, 60, 24, 24, 9, 21, 26.4959116294049, 12, 
                                  5, 12, 12, 48, 18, 48, 12, 17.4877766738646, 36, 9, 15, 39.2811119947582, 
                                  27, 21, 24, 10, 6, 12, 12, 24, 39, 18, 24, 15, 48, 12, 24, 26.7659258879721, 
                                  36, 24, 27, 9, 12, 48, 28, 21, 6, 24, 24, 24, 18, 36, 36, 30, 
                                  8.19771710922942, 36, 18, 12, 13.8241796996444, 26.0928970947862, 
                                  10, 36, 12, 12, 24, 21.3157193372026, 18, 21, 24, 24)), 
                class = c("tbl_df", 
                          "tbl", "data.frame"), 
                row.names = c(NA, -100L))

由 reprex 包 (v2.0.1) 于 2022 年 1 月 12 日创建

会话信息
# Print the R version, platform, locale and loaded package versions used for
# this reprex.
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.1.2 (2021-11-01)
#>  os       macOS Big Sur 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  es_ES.UTF-8
#>  ctype    es_ES.UTF-8
#>  tz       Europe/Madrid
#>  date     2022-01-12
#>  pandoc   2.14.0.3 @ /Applications/RStudio.app/Contents/MacOS/pandoc/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.1.0)
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.1.0)
#>  broom        * 0.7.11     2022-01-03 [1] CRAN (R 4.1.2)
#>  class          7.3-19     2021-05-03 [1] CRAN (R 4.1.2)
#>  cli            3.1.0      2021-10-27 [1] CRAN (R 4.1.0)
#>  codetools      0.2-18     2020-11-04 [1] CRAN (R 4.1.2)
#>  colorspace     2.0-2      2021-06-24 [1] CRAN (R 4.1.0)
#>  crayon         1.4.2      2021-10-29 [1] CRAN (R 4.1.0)
#>  DBI            1.1.2      2021-12-20 [1] CRAN (R 4.1.2)
#>  dials        * 0.0.10     2021-09-10 [1] CRAN (R 4.1.0)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.1.0)
#>  digest         0.6.29     2021-12-01 [1] CRAN (R 4.1.0)
#>  dplyr        * 1.0.7      2021-06-18 [1] CRAN (R 4.1.0)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.1.0)
#>  evaluate       0.14       2019-05-28 [1] CRAN (R 4.1.0)
#>  fansi          1.0.0      2022-01-10 [1] CRAN (R 4.1.2)
#>  fastmap        1.1.0      2021-01-25 [1] CRAN (R 4.1.0)
#>  foreach        1.5.1      2020-10-15 [1] CRAN (R 4.1.0)
#>  fs             1.5.2      2021-12-08 [1] CRAN (R 4.1.0)
#>  furrr          0.2.3      2021-06-25 [1] CRAN (R 4.1.0)
#>  future         1.23.0     2021-10-31 [1] CRAN (R 4.1.0)
#>  future.apply   1.8.1      2021-08-10 [1] CRAN (R 4.1.0)
#>  generics       0.1.1      2021-10-25 [1] CRAN (R 4.1.0)
#>  ggplot2      * 3.3.5      2021-06-25 [1] CRAN (R 4.1.0)
#>  globals        0.14.0     2020-11-22 [1] CRAN (R 4.1.0)
#>  glue           1.6.0      2021-12-17 [1] CRAN (R 4.1.0)
#>  gower          0.2.2      2020-06-23 [1] CRAN (R 4.1.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.1.0)
#>  gtable         0.3.0      2019-03-25 [1] CRAN (R 4.1.0)
#>  hardhat        0.1.6      2021-07-14 [1] CRAN (R 4.1.0)
#>  highr          0.9        2021-04-16 [1] CRAN (R 4.1.0)
#>  htmltools      0.5.2      2021-08-25 [1] CRAN (R 4.1.0)
#>  infer        * 1.0.0      2021-08-13 [1] CRAN (R 4.1.0)
#>  ipred          0.9-12     2021-09-15 [1] CRAN (R 4.1.0)
#>  iterators      1.0.13     2020-10-15 [1] CRAN (R 4.1.0)
#>  knitr          1.37       2021-12-16 [1] CRAN (R 4.1.0)
#>  lattice        0.20-45    2021-09-22 [1] CRAN (R 4.1.2)
#>  lava           1.6.10     2021-09-02 [1] CRAN (R 4.1.0)
#>  lhs            1.1.3      2021-09-08 [1] CRAN (R 4.1.0)
#>  lifecycle      1.0.1      2021-09-24 [1] CRAN (R 4.1.0)
#>  listenv        0.8.0      2019-12-05 [1] CRAN (R 4.1.0)
#>  lubridate      1.8.0      2021-10-07 [1] CRAN (R 4.1.0)
#>  magrittr       2.0.1      2020-11-17 [1] CRAN (R 4.1.0)
#>  MASS           7.3-54     2021-05-03 [1] CRAN (R 4.1.2)
#>  Matrix         1.4-0      2021-12-08 [1] CRAN (R 4.1.0)
#>  mgcv           1.8-38     2021-10-06 [1] CRAN (R 4.1.2)
#>  modeldata    * 0.1.1      2021-07-14 [1] CRAN (R 4.1.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.1.0)
#>  nlme           3.1-153    2021-09-07 [1] CRAN (R 4.1.2)
#>  nnet           7.3-16     2021-05-03 [1] CRAN (R 4.1.2)
#>  parallelly     1.30.0     2021-12-17 [1] CRAN (R 4.1.0)
#>  parsnip      * 0.1.7      2021-07-21 [1] CRAN (R 4.1.0)
#>  pillar         1.6.4      2021-10-18 [1] CRAN (R 4.1.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.1.0)
#>  plyr           1.8.6      2020-03-03 [1] CRAN (R 4.1.0)
#>  pROC           1.18.0     2021-09-03 [1] CRAN (R 4.1.0)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.1.0)
#>  purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.1.0)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.1.0)
#>  Rcpp           1.0.7      2021-07-07 [1] CRAN (R 4.1.0)
#>  recipes      * 0.1.17     2021-09-27 [1] CRAN (R 4.1.0)
#>  reprex         2.0.1      2021-08-05 [1] CRAN (R 4.1.0)
#>  rlang          0.4.12     2021-10-18 [1] CRAN (R 4.1.0)
#>  rmarkdown      2.11       2021-09-14 [1] CRAN (R 4.1.0)
#>  rpart          4.1-15     2019-04-12 [1] CRAN (R 4.1.2)
#>  rsample      * 0.1.1      2021-11-08 [1] CRAN (R 4.1.0)
#>  rstudioapi     0.13       2020-11-12 [1] CRAN (R 4.1.0)
#>  scales       * 1.1.1      2020-05-11 [1] CRAN (R 4.1.0)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.1.0)
#>  stringi        1.7.6      2021-11-29 [1] CRAN (R 4.1.0)
#>  stringr        1.4.0      2019-02-10 [1] CRAN (R 4.1.0)
#>  survival       3.2-13     2021-08-24 [1] CRAN (R 4.1.2)
#>  tibble       * 3.1.6      2021-11-07 [1] CRAN (R 4.1.0)
#>  tidymodels   * 0.1.4.9000 2022-01-12 [1] Github (tidymodels/tidymodels@8486957)
#>  tidyr        * 1.1.4      2021-09-27 [1] CRAN (R 4.1.0)
#>  tidyselect     1.1.1      2021-04-30 [1] CRAN (R 4.1.0)
#>  timeDate       3043.102   2018-02-21 [1] CRAN (R 4.1.0)
#>  tune         * 0.1.6      2021-07-21 [1] CRAN (R 4.1.0)
#>  utf8           1.2.2      2021-07-24 [1] CRAN (R 4.1.0)
#>  vctrs          0.3.8      2021-04-29 [1] CRAN (R 4.1.0)
#>  withr          2.4.3      2021-11-30 [1] CRAN (R 4.1.0)
#>  workflows    * 0.2.4      2021-10-12 [1] CRAN (R 4.1.0)
#>  workflowsets * 0.1.0      2021-07-22 [1] CRAN (R 4.1.0)
#>  xfun           0.29       2021-12-14 [1] CRAN (R 4.1.0)
#>  yaml           2.2.1      2020-02-01 [1] CRAN (R 4.1.0)
#>  yardstick    * 0.0.9      2021-11-22 [1] CRAN (R 4.1.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

此问题已在 {parsnip} 的开发版本 (> 0.1.7) 中修复。您可以运行 remotes::install_github("tidymodels/parsnip") 来安装该开发版本：
library(parsnip)
library(rsample)

# Example data (same dput() as in the question): 100 rows, factor outcome
# `class` ("bad"/"good") and numeric predictor `duration`.
df <- structure(list(class = structure(c(2L, 1L, 2L, 2L, 2L, 2L, 2L, 
                                         2L, 1L, 2L, 2L, 1L, 1L, 2L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, 1L, 2L, 
                                         1L, 2L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 2L, 2L, 1L, 
                                         1L, 2L, 1L, 2L, 2L, 2L, 2L, 1L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 
                                         2L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 
                                         1L, 2L, 1L, 2L, 2L, 2L, 1L, 1L, 1L, 2L, 1L, 2L, 1L, 1L, 1L, 1L, 
                                         2L, 1L, 1L, 2L, 1L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 1L), 
                                       .Label = c("bad", 
                                                  "good"), class = "factor"), 
                     duration = c(42, 31.7869911119342, 
                                  18, 24, 12, 18, 10, 9, 12, 24, 10, 27, 14.4910072591156, 12, 
                                  48, 24, 30, 18, 6, 6, 12, 48, 10, 18, 6, 12, 24.4157173759304, 
                                  18, 48, 60, 18, 15, 9, 60, 24, 24, 9, 21, 26.4959116294049, 12, 
                                  5, 12, 12, 48, 18, 48, 12, 17.4877766738646, 36, 9, 15, 39.2811119947582, 
                                  27, 21, 24, 10, 6, 12, 12, 24, 39, 18, 24, 15, 48, 12, 24, 26.7659258879721, 
                                  36, 24, 27, 9, 12, 48, 28, 21, 6, 24, 24, 24, 18, 36, 36, 30, 
                                  8.19771710922942, 36, 18, 12, 13.8241796996444, 26.0928970947862, 
                                  10, 36, 12, 12, 24, 21.3157193372026, 18, 21, 24, 24)), 
                class = c("tbl_df", 
                          "tbl", "data.frame"), 
                row.names = c(NA, -100L))

# initial_split() samples rows at random; fix the seed so the stratified
# 75/25 split, and therefore the predictions, are reproducible.
set.seed(123)
df_split <- initial_split(df, prop = 0.75, strata = class)
df_train <- training(df_split)
df_test <- testing(df_split)

# With the development version of parsnip (> 0.1.7), predict() works for the
# mgcv engine in classification mode.
# NOTE: the formula has no s() term, so mgcv fits a purely parametric model
# here; use class ~ s(duration) to get a smooth effect of `duration`.
gen_model <- gen_additive_mod(mode = "classification",
                              engine = "mgcv") %>%
  fit(class ~ duration, data = df_train)

# Output below was captured before set.seed() was added; rows may differ.
predict(gen_model, df_test)
#> # A tibble: 26 × 1
#>    .pred_class
#>    <fct>      
#>  1 bad        
#>  2 good       
#>  3 good       
#>  4 good       
#>  5 bad        
#>  6 good       
#>  7 good       
#>  8 good       
#>  9 good       
#> 10 good       
#> # … with 16 more rows

由 reprex 包 (v2.0.1) 于 2022-01-12 创建