通过静态-动态分支管道中的批处理提高并行性能

Improve parallel performance with batching in a static-dynamic branching pipeline

BLUF: 我正在努力了解如何使用 batching in the R targets 包来提高使用 [=12= 并行处理的静态和动态分支管道的性能].我想我需要在每个动态分支中进行批处理,但我不确定如何去做。

这是一个使用嵌套在静态分支中的动态分支的 reprex,类似于我的实际管道正在做的事情。它首先为 all_types 中的每个值静态分支,然后在每个类别内动态分支。此代码总共产生 1,000 个分支和 1,010 个目标。在实际工作流程中我显然没有使用replicate,动态分支的数量取决于type值。

# _targets.r

library(targets)
library(tarchetypes)
library(future)
library(future.callr)

plan(callr)

all_types = data.frame(type = LETTERS[1:10])

tar_map(values = all_types, names = "type",
  tar_target(
    make_data,
    replicate(100,
      data.frame(x = seq(1000) + rnorm(1000, 0, 5),
                 y = seq(1000) + rnorm(1000, 20, 20)),
      simplify = FALSE
    ),
    iteration = "list"
  ),
  tar_target(
    fit_model,
    lm(make_data),
    pattern = map(make_data),
    iteration = "list"
  )
)

这里是 tar_make()tar_make_future() 八个工人的时间比较:

# tar_destroy()
t1 <- system.time(tar_make())
# tar_destroy()
t2 <- system.time(tar_make_future(workers = 8))

rbind(serial = t1, parallel = t2)

##          user.self sys.self elapsed user.child sys.child
## serial        2.12     0.11   25.59         NA        NA
## parallel      2.07     0.24  184.68         NA        NA

我认为 usersystem 字段在这里没有用,因为作业被分派到单独的 R 进程,但是并行作业的 elapsed 时间大约需要比串行作业长 7 倍。

我认为这种减速是由于大量目标造成的。在这种情况下,批处理会提高性能吗?如果可以,我该如何在动态分支中实现批处理?

您在批处理方面走在了正确的轨道上。在您的情况下,这是将您的 100 个数据集列表分成 10 个左右的组的问题。您可以使用嵌套的数据集列表来执行此操作,但这需要大量工作。幸运的是,有一种更简单的方法。

你的问题问得真是恰逢其时。我刚刚在 tarchetypes 中写了一些新的目标工厂,它们可能会有所帮助。要访问它们,您需要来自 GitHub:

tarchetypes 的开发版本
remotes::install_github("ropensci/tarchetypes")

然后,使用 tar_map2_count(),为每个场景批处理 100 个数据集的列表会容易得多。

library(targets)
tar_script({
  library(broom)
  library(targets)
  library(tarchetypes)
  library(tibble)

  make_data <- function(n) {
    datasets_per_batch <- replicate(
      100,
      tibble(
        x = seq(n) + rnorm(n, 0, 5),
        y = seq(n) + rnorm(n, 20, 20)
      ),
      simplify = FALSE
    )
    tibble(dataset = datasets_per_batch, rep = seq_along(datasets_per_batch))
  }

  tar_map2_count(
    name = model,
    command1 = make_data(n = rows),
    command2 = tidy(lm(y ~ x, data = dataset)), # Need dataset[[1]] in tarchetypes 0.4.0
    values = data_frame(
      scenario = LETTERS[seq_len(10)],
      rows = seq(10, 100, length.out = 10)
    ),
    columns2 = NULL,
    batches = 10
  )
})
tar_make(reporter = "silent")
#> Warning message:
#> `data_frame()` was deprecated in tibble 1.1.0.
#> Please use `tibble()` instead.
#> This warning is displayed once every 8 hours.
#> Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
tar_read(model)
#> # A tibble: 2,000 × 8
#>    term        estimate std.error statistic   p.value scenario  rows tar_group
#>    <chr>          <dbl>     <dbl>     <dbl>     <dbl> <chr>    <dbl>     <int>
#>  1 (Intercept)   17.1      12.8       1.34  0.218     A           10        10
#>  2 x              1.39      1.35      1.03  0.333     A           10        10
#>  3 (Intercept)    6.42     14.0       0.459 0.658     A           10        10
#>  4 x              1.75      1.28      1.37  0.209     A           10        10
#>  5 (Intercept)   32.8       7.14      4.60  0.00176   A           10        10
#>  6 x             -0.300     1.14     -0.263 0.799     A           10        10
#>  7 (Intercept)   29.7       3.24      9.18  0.0000160 A           10        10
#>  8 x              0.314     0.414     0.758 0.470     A           10        10
#>  9 (Intercept)   20.0      13.6       1.47  0.179     A           10        10
#> 10 x              1.23      1.77      0.698 0.505     A           10        10
#> # … with 1,990 more rows

reprex package (v2.0.1)

于 2021-12-10 创建

还有tar_map_rep(),如果你所有的数据集都是随机生成的,这可能会更容易,但我不确定我是否过度拟合你的用例。

library(targets)
tar_script({
  library(broom)
  library(targets)
  library(tarchetypes)
  library(tibble)

  make_one_dataset <- function(n) {
    tibble(
      x = seq(n) + rnorm(n, 0, 5),
      y = seq(n) + rnorm(n, 20, 20)
    )
  }

  tar_map_rep(
    name = model,
    command = tidy(lm(y ~ x, data = make_one_dataset(n = rows))),
    values = data_frame(
      scenario = LETTERS[seq_len(10)],
      rows = seq(10, 100, length.out = 10)
    ),
    batches = 10,
    reps = 10
  )
})
tar_make(reporter = "silent")
#> Warning message:
#> `data_frame()` was deprecated in tibble 1.1.0.
#> Please use `tibble()` instead.
#> This warning is displayed once every 8 hours.
#> Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
tar_read(model)
#> # A tibble: 2,000 × 10
#>    term    estimate std.error statistic p.value scenario  rows tar_batch tar_rep
#>    <chr>      <dbl>     <dbl>     <dbl>   <dbl> <chr>    <dbl>     <int>   <int>
#>  1 (Inter…   37.5        7.50     5.00  0.00105 A           10         1       1
#>  2 x         -0.701      1.17    -0.601 0.564   A           10         1       1
#>  3 (Inter…   21.5        9.64     2.23  0.0567  A           10         1       2
#>  4 x         -0.213      1.55    -0.138 0.894   A           10         1       2
#>  5 (Inter…   20.6        9.51     2.17  0.0620  A           10         1       3
#>  6 x          1.40       1.79     0.783 0.456   A           10         1       3
#>  7 (Inter…   11.6       11.2      1.04  0.329   A           10         1       4
#>  8 x          2.34       1.39     1.68  0.131   A           10         1       4
#>  9 (Inter…   26.8        9.16     2.93  0.0191  A           10         1       5
#> 10 x          0.288      1.10     0.262 0.800   A           10         1       5
#> # … with 1,990 more rows, and 1 more variable: tar_group <int>

reprex package (v2.0.1)

于 2021-12-10 创建

不幸的是,期货确实有开销。如果你尝试 tar_make_clustermq()?

也许你的情况会更快