R tidymodels:tune_grid 报错 “no applicable method for 'prep'”

R tidymodels: tune_grid fails with "no applicable method for 'prep'"

我在使用自己为一个包编写的自定义 recipe 步骤函数 step_hai_fourier。它目前只存在于包的本地开发版本中(尚未提交到 CRAN),单独使用时运行正常。

这是会话信息(我们可以看到 healthyR.ai 0.0.2.9000 已加载):

> sessionInfo()
R version 4.1.0 (2021-05-18)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19042)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252   
[3] LC_MONETARY=English_United States.1252 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.1252    

attached base packages:
[1] parallel  stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] timetk_2.6.1               tidyquant_1.0.3            quantmod_0.4.18           
 [4] TTR_0.24.2                 PerformanceAnalytics_2.0.4 xts_0.12.1                
 [7] zoo_1.8-9                  lubridate_1.8.0            forcats_0.5.1             
[10] stringr_1.4.0              readr_2.0.2                tidyverse_1.3.1           
[13] janitor_2.1.0              healthyR.ts_0.1.4          healthyR.ai_0.0.2.9000     
[16] yardstick_0.0.8            workflowsets_0.1.0         workflows_0.2.4           
[19] tune_0.1.6                 tidyr_1.1.4                tibble_3.1.6              
[22] rsample_0.1.1              recipes_0.1.17             purrr_0.3.4               
[25] parsnip_0.1.7              modeldata_0.1.1            infer_1.0.0               
[28] ggplot2_3.3.5              dplyr_1.0.7                dials_0.0.10              
[31] scales_1.1.1               broom_0.7.10               tidymodels_0.1.4          
[34] modeltime_1.1.0           

loaded via a namespace (and not attached):
  [1] readxl_1.3.1         backports_1.3.0      plyr_1.8.6           lazyeval_0.2.2      
  [5] splines_4.1.0        crosstalk_1.2.0      listenv_0.8.0        digest_0.6.28       
  [9] foreach_1.5.1        htmltools_0.5.2      fansi_0.5.0          checkmate_2.0.0     
 [13] magrittr_2.0.1       doParallel_1.0.16    tzdb_0.2.0           globals_0.14.0      
 [17] modelr_0.1.8         gower_0.2.2          RcppParallel_5.1.4   vroom_1.5.5         
 [21] hardhat_0.1.6        forecast_8.15        tseries_0.10-48      colorspace_2.0-2    
 [25] rvest_1.0.2          haven_2.4.3          crayon_1.4.2         jsonlite_1.7.2      
 [29] survival_3.2-11      iterators_1.0.13     glue_1.5.0           gtable_0.3.0        
 [33] ipred_0.9-12         Quandl_2.11.0        future.apply_1.8.1   DBI_1.1.1           
 [37] Rcpp_1.0.7           viridisLite_0.4.0    GPfit_1.0-8          bit_4.0.4           
 [41] lava_1.6.10          StanHeaders_2.21.0-7 prodlim_2019.11.13   htmlwidgets_1.5.4   
 [45] httr_1.4.2           ellipsis_0.3.2       pkgconfig_2.0.3      sass_0.4.0          
 [49] nnet_7.3-16          dbplyr_2.1.1         utf8_1.2.2           tidyselect_1.1.1    
 [53] labeling_0.4.2       rlang_0.4.12         DiceDesign_1.9       munsell_0.5.0       
 [57] cellranger_1.1.0     tools_4.1.0          xgboost_1.5.0.1      cli_3.1.0           
 [61] generics_0.1.1       fastmap_1.1.0        yaml_2.2.1           bit64_4.0.5         
 [65] fs_1.5.0             future_1.23.0        nlme_3.1-152         xml2_1.3.2          
 [69] compiler_4.1.0       rstudioapi_0.13      plotly_4.10.0        curl_4.3.2          
 [73] gt_0.3.1             reprex_2.0.1         lhs_1.1.3            stringi_1.7.5       
 [77] lattice_0.20-44      Matrix_1.3-4         urca_1.3-0           vctrs_0.3.8         
 [81] pillar_1.6.4         lifecycle_1.0.1      furrr_0.2.3          lmtest_0.9-39       
 [85] data.table_1.14.2    R6_2.5.1             parallelly_1.28.1    codetools_0.2-18    
 [89] MASS_7.3-54          assertthat_0.2.1     withr_2.4.2          fracdiff_1.5-1      
 [93] hms_1.1.1            quadprog_1.5-8       grid_4.1.0           rpart_4.1-15        
 [97] timeDate_3043.102    class_7.3-19         snakecase_0.11.0     pROC_1.18.0   

下面是我的脚本,可以看到 prep() 和 juice() 都能正常运行:

# Libraries ----
library(modeltime)
library(tidymodels)
library(healthyR.ai)
library(healthyR.ts)
library(parallel)
library(janitor)
library(tidyverse)
library(tidyquant)
library(timetk)

# Data ----
url      <- "https://cci30.com/ajax/getIndexHistory.php"
destfile <- "00_data/cci30_OHLCV.csv"
# download.file() does not create missing directories, so make sure the
# target folder exists first (no-op when it already does).
dir.create(dirname(destfile), showWarnings = FALSE, recursive = TRUE)
download.file(url, destfile = destfile)

# Read the file back through `destfile` so the path is defined in one place.
cci_index_tbl <- read_csv(destfile) %>%
  clean_names()

# * Log Returns (period set by `time_param`) ----
time_param <- "weekly"
log_returns_tbl <- cci_index_tbl %>%
  tq_transmute(
    select = close
    , mutate_fun = periodReturn
    , period = time_param
    , type = "log"
    , col_rename = "value"
  ) %>%
  set_names("date_col", "value")

# * Train/Test ----
# Assess on the last 12 weeks; the training window is cumulative.
splits <- log_returns_tbl %>%
  time_series_split(
    date_var = date_col
    , assess = "12 weeks"
    , cumulative = TRUE
  )

splits %>%
  tk_time_series_cv_plan() %>%
  plot_time_series_cv_plan(
    .date_var = date_col
    , .value  = value
    , .title  = paste0(
      "CCI30 ", stringr::str_to_title(time_param), " Log Returns"
    )
  )

# Leave one core free for the interactive session, but always keep at least
# one worker. detectCores() can return NA on some platforms, hence na.rm.
n_cores <- max(1L, detectCores() - 1L, na.rm = TRUE)

# Recipe ----
# Model `value` on the remaining columns and add a sin/cos Fourier feature
# of `value` using the custom healthyR.ai step.
recipe_base <- recipe(value ~ ., data = training(splits)) %>%
  step_hai_fourier(value, scale_type = "sincos", period = 12, order = 1)

recipe_base %>% prep() %>% juice()

控制台输出如下:

> recipe_base %>% prep() %>% juice()
# A tibble: 347 x 3
   date_col      value value_sincos
   <date>        <dbl>        <dbl>
 1 2015-01-04 -0.196       -0.102  
 2 2015-01-11 -0.00811     -0.00425
 3 2015-01-18 -0.209       -0.108  
 4 2015-01-25  0.185        0.0961 
 5 2015-02-01 -0.141       -0.0736 
 6 2015-02-08 -0.0231      -0.0121 
 7 2015-02-15  0.0208       0.0109 
 8 2015-02-22 -0.0180      -0.00941
 9 2015-03-01  0.0841       0.0440 
10 2015-03-08 -0.0182      -0.00955
# ... with 337 more rows

到这里一切正常,下面是脚本的其余部分:

# Model ----
# Boosted Auto ARIMA
model_spec_arima_boosted <- arima_boost(
  min_n = 2
  , learn_rate = 0.015
) %>%
  set_engine(engine = "auto_arima_xgboost")

# Workflowset ----
wfsets <- workflow_set(
  preproc = list(
    base          = recipe_base
  ),
  models = list(
    model_spec_arima_boosted
  ),
  cross = TRUE
)

parallel_start(n_cores)
wf_fits <- wfsets %>% 
  modeltime_fit_workflowset(
    data = training(splits)
    , control = control_fit_workflowset(
      allow_par = TRUE
      , verbose = TRUE
    )
  )
parallel_stop()

# Drop workflows whose fit failed.
wf_fits <- wf_fits %>%
  filter(.model != "NULL")

# Model Table -------------------------------------------------------------

models_tbl <- wf_fits

# Calibrate Model Testing -------------------------------------------------

parallel_start(n_cores)

calibration_tbl <- models_tbl %>%
  modeltime_calibrate(new_data = testing(splits))

parallel_stop()

calibration_tbl

# Testing Accuracy --------------------------------------------------------

parallel_start(n_cores)

calibration_tbl %>%
  modeltime_forecast(
    new_data    = testing(splits),
    actual_data = log_returns_tbl
  ) %>%
  plot_modeltime_forecast(
    .legend_max_width   = 25,
    .interactive        = TRUE,
    .conf_interval_show = FALSE
  )

parallel_stop()

calibration_tbl %>%
  modeltime_accuracy() %>%
  drop_na() %>%
  arrange(desc(rsq)) %>%
  table_modeltime_accuracy(.interactive = FALSE)

# Model Tuning ----
# Get Model
plucked_model <- calibration_tbl %>%
  modeltime::pluck_modeltime_model(1)

training_data <- rsample::training(splits)

tscv <- timetk::time_series_cv(
  data        = training_data,
  date_var    = date_col,
  cumulative  = TRUE,
  assess      = "26 weeks",
  skip        = "4 weeks",
  slice_limit = 6
)

# * Tune Spec ----
# Model Spec
model_spec <- plucked_model %>% parsnip::extract_spec_parsnip()
model_spec_engine <- model_spec[["engine"]]
model_spec_tuner <- healthyR.ts::ts_model_spec_tune_template(model_spec_engine)

# * Grid Spec ----
grid_spec <- dials::grid_latin_hypercube(
  tune::parameters(model_spec_tuner),
  size = 30
)

# * Tune Model ----
wflw_tune_spec <- plucked_model %>%
  workflows::update_model(model_spec_tuner)

# * Run Tuning Grid ----
# NOTE: with the published healthyR.ai build, every resample recorded
# "no applicable method for 'prep'" in .notes. The fix was made in
# healthyR.ai itself, which registers tune::required_pkgs() methods in
# .onLoad. Presumably this lets the parallel workers load the package that
# defines step_hai_fourier; see tune's required_pkgs() docs to confirm.
modeltime::parallel_start(n_cores)

# THIS FAILS ----
# `finally` shuts the cluster down even if tune_grid() errors, matching the
# parallel_start()/parallel_stop() pairs used in the sections above.
tune_results <- tryCatch(
  wflw_tune_spec %>%
    tune::tune_grid(
      resamples = tscv,
      grid = grid_spec,
      metrics = modeltime::default_forecast_accuracy_metric_set(),
      control = tune::control_grid(
        verbose = TRUE,
        save_pred = TRUE
      )
    ),
  finally = modeltime::parallel_stop()
)

> tune_results$.notes[[1]]
# A tibble: 1 x 1
  .notes                                                                                            
  <chr>                                                                                             
1 "preprocessor 1/1: Error in UseMethod(\"prep\"): no applicable method for 'prep' applied to an ob~

我不确定为什么会这样,也不确定这个问题应该提给 tune 还是 recipes;之所以提在这里,是因为它是在 tune_grid() 中失败的。

解决方法:在包中添加 .onLoad 函数,在包加载时为 tune::required_pkgs 注册 S3 方法。

提交:https://github.com/spsanderson/healthyR.ai/commit/89671ff138e61d07a5dbdfcd7e0a694144aa3e08

# Package load hook: registers S3 methods for generics owned by other
# packages (currently tune::required_pkgs()) each time healthyR.ai loads.
#
# @param libname,pkgname Supplied by R's loader; unused here.
.onLoad <- function(libname, pkgname) {
    maybe_register_s3_methods()
}
# nocov start
# Register every `required_pkgs.<class>` function defined in the healthyR.ai
# namespace as a method for tune::required_pkgs(). This only happens when an
# installed tune version provides that generic (>= 0.1.1.9000).
#
# @return NULL, invisibly. Called for its registration side effect.
maybe_register_s3_methods <- function() {

    ns <- asNamespace("healthyR.ai")
    ns_names <- names(ns)

    if (rlang::is_installed("tune") && utils::packageVersion("tune") >= "0.1.1.9000") {

        # The dot is escaped twice: once for the R string literal and once
        # for the regex. A single "\." is not a valid R string escape.
        prefix_regex <- "^required_pkgs\\."

        req_pkgs_names <- grep(prefix_regex, ns_names, value = TRUE)
        # Strip only the leading "required_pkgs." prefix to get the class.
        # An unanchored gsub() with an unescaped "." could also remove other
        # matches inside the class name.
        req_pkgs_classes <- sub(prefix_regex, "", req_pkgs_names)

        for (cls in req_pkgs_classes) {
            s3_register("tune::required_pkgs", cls)
        }
    }

    # ----------------------------------------------------------------------------

    invisible()
}

# vctrs:::s3_register()
# Register an S3 method for a generic owned by another package, without a
# hard dependency on it. Adapted from vctrs:::s3_register().
#
# Registration happens immediately if the owning package is already loaded.
# A packageEvent "onLoad" hook is also installed, so the method is
# re-registered whenever that package is (re)loaded later.
#
# @param generic "pkg::generic", e.g. "tune::required_pkgs".
# @param class Class name to register the method for.
# @param method Method function. If NULL, `<generic>.<class>` is looked up
#   in the caller's namespace (or the caller's environment when not
#   called from a namespace).
# @return NULL, invisibly. Errors immediately if the method cannot be found
#   or is not a function.
s3_register <- function(generic, class, method = NULL) {
    stopifnot(is.character(generic), length(generic) == 1)
    stopifnot(is.character(class), length(class) == 1)

    pieces <- strsplit(generic, "::")[[1]]
    stopifnot(length(pieces) == 2)
    package <- pieces[[1]]
    generic <- pieces[[2]]

    caller <- parent.frame()

    get_method_env <- function() {
        top <- topenv(caller)
        if (isNamespace(top)) {
            asNamespace(environmentName(top))
        } else {
            caller
        }
    }
    get_method <- function(method) {
        if (is.null(method)) {
            get(paste0(generic, ".", class), envir = get_method_env())
        } else {
            method
        }
    }

    # Shared by the immediate path and the onLoad hook. Both paths skip
    # registration when the owning package does not provide the generic
    # (e.g. an older version), instead of erroring inside its load.
    register <- function(...) {
        envir <- asNamespace(package)

        # Refresh the method, it might have been updated by `devtools::load_all()`
        method_fn <- get_method(method)
        stopifnot(is.function(method_fn))

        if (exists(generic, envir)) {
            registerS3method(generic, class, method_fn, envir = envir)
        }
        invisible()
    }

    # Validate eagerly, so a missing or non-function method fails at call time.
    method_fn <- get_method(method)
    stopifnot(is.function(method_fn))

    # Always register hook in case package is later unloaded & reloaded
    setHook(packageEvent(package, "onLoad"), register)

    # Avoid registration failures during loading (pkgload or regular)
    if (!isNamespaceLoaded(package)) {
        return(invisible())
    }

    register()

    invisible()
}