如何从数据框中构造 case_when 的参数?

How to construct arguments for case_when from data frame?

我正在尝试根据温度创建许多不同的可能加权方案。

我创建了一个数据框,其中包含 8 个向量的所有可能组合(每个向量代表一个温度范围)。所以数据框的列是特定的温度范围,行是权重。

我想将温度范围作为参数传递给 case_when,并遍历权重数据框的每一行,根据实际温度和相关权重为每一行创建一个新变量该温度基于权重数据框中的信息。

使用以下 post,我能够创建一个函数来生成权重数据框:

但我不知道如何使用权重数据框构建 case_when 参数。

创建所有可能权重的数据框的函数

library(rlang)
library(tidyverse)

create_temp_weights <- function(
  from = 31,
  to = 100,
  by = 10,
  weights = exprs(between(., 31, 40) ~ c(0, 0.2),
                  between(., 41, 50) ~ c(0.5, 0.8),
                  between(., 51, 90) ~ c(0.8, 1),
                  between(., 91, 100) ~ c(0.2, 0.8),
                  TRUE ~ c(-0.1, 0))
) {

  # use 999 to map other temperatures to last case
  map(c(seq(from, to, by), 999), ~ case_when(!!!weights)) %>%
    set_names(c(map_chr(seq(from, to, by),
                      ~ str_c("temp_", ., "_", . + by - 1)), "temp_other")) %>%
  cross_df(.)

}

temp_weights <- create_temp_weights()

使用用于构建权重的温度向量创建 tibble

test_tibble <- tibble(temp = seq_len(100))

head(test_tibble)

以下 case_when 是我尝试使用权重数据框以编程方式生成的内容。

# Now I want to create a function that will produce the following
# case_when from the temp_weight data frame so I don't have to
# manually edit the following each time I create a new weights data frame

test_tibble2 <- map_dfc(.x = seq_len(nrow(temp_weights)),
    ~ transmute(
      test_tibble,
      temp =
        case_when(
          temp >= 31   & temp  <= 40   ~  temp_weights$temp_31_40[.x],
          temp >= 41   & temp  <= 50   ~  temp_weights$temp_41_50[.x],
          temp >= 51   & temp  <= 60   ~  temp_weights$temp_51_60[.x],
          temp >= 61   & temp  <= 70   ~  temp_weights$temp_61_70[.x],
          temp >= 71   & temp  <= 80   ~  temp_weights$temp_71_80[.x],
          temp >= 81   & temp  <= 90   ~  temp_weights$temp_81_90[.x],
          temp >= 91   & temp  <= 100  ~  temp_weights$temp_91_100[.x],
          TRUE & !is.na(temp)          ~  temp_weights$temp_other[.x]
        )
      ) %>% set_names(paste0("temp_wt_", .x))
    ) 

head(test_tibble2)

所以我正在寻找的是一个从权重数据框构造 case_when 参数的函数。

非常模仿 OP:

windows <- 
  str_extract_all(names(temp_weights), "\d+") %>% 
  modify(as.integer) %>% 
  modify_if(negate(length), ~ c(-Inf, Inf)) %>% 
  set_names(names(temp_weights))

temp <- test_tibble$temp

res <-
  map_dfc(
    seq_len(nrow(temp_weights)), 
    ~ {
      row <- .
      rlang::eval_tidy(expr(case_when(
        !!! imap(
          windows, 
          ~ expr(
            between(temp, !! .x[1], !! .x[2]) ~ !! temp_weights[[.y]][row]
          )
        )
      )))
    }
  ) %>% 
  set_names(paste0("temp_wt_", seq_along(.)))

all.equal(res, test_tibble2)
#> [1] TRUE 

效率稍微高一点(不为每个权重组合重复case_when):

res2 <- 
  rlang::eval_tidy(expr(case_when(
    !!! imap(
      windows, 
      ~ expr(
        between(temp, !! .x[1], !! .x[2]) ~ list(!! temp_weights[[.y]])
      )
    )
  ))) %>% 
  do.call(what = rbind) %>% 
  as_tibble() %>% 
  set_names(paste0("temp_wt_", seq_along(.)))

all.equal(res2, test_tibble2)
#> [1] TRUE   

这是为了补充 Aurèle 接受的答案。

在这里,我比较了 Aurèle 提出的两个解决方案和使用 data.table 的最终解决方案之间的效率,后者还提供了保留 NA 的选项。

suppressPackageStartupMessages(library(rlang))
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tictoc))

create_temp_weights <- function(
  from = 31,
  to = 100,
  by = 10,
  weights = exprs(between(., 31, 40) ~ c(0, 0.2),
                  between(., 41, 50) ~ c(0.5, 0.8),
                  between(., 51, 90) ~ c(0.8, 1),
                  between(., 91, 100) ~ c(0.2, 0.8),
                  TRUE ~ c(-0.1, 0))
) {

  # use 999 to map other temperatures to last case
  map(c(seq(from, to, by), 999), ~ case_when(!!!weights)) %>%
    set_names(c(map_chr(seq(from, to, by),
                        ~ str_c("temp_", ., "_", . + by - 1)), "temp_other")) %>%
    cross_df(.)

}

temp_weights <- create_temp_weights()

test_tibble <- tibble(temp = rnorm(1000000, 50, 15))

test_tibble2 <- map_dfc(.x = seq_len(nrow(temp_weights)),
                        ~ transmute(
                          test_tibble,
                          temp =
                            case_when(
                              temp >= 31   & temp  <= 40   ~  temp_weights$temp_31_40[.x],
                              temp >= 41   & temp  <= 50   ~  temp_weights$temp_41_50[.x],
                              temp >= 51   & temp  <= 60   ~  temp_weights$temp_51_60[.x],
                              temp >= 61   & temp  <= 70   ~  temp_weights$temp_61_70[.x],
                              temp >= 71   & temp  <= 80   ~  temp_weights$temp_71_80[.x],
                              temp >= 81   & temp  <= 90   ~  temp_weights$temp_81_90[.x],
                              temp >= 91   & temp  <= 100  ~  temp_weights$temp_91_100[.x],
                              TRUE & !is.na(temp)          ~  temp_weights$temp_other[.x]
                            )
                        ) %>% set_names(paste0("temp_wt_", .x))
) 

windows <- 
  str_extract_all(names(temp_weights), "\d+") %>% 
  modify(as.integer) %>% 
  modify_if(negate(length), ~ c(-Inf, Inf)) %>% 
  set_names(names(temp_weights))

解决方案 #1

temp <- test_tibble$temp

tic()
res <-
  map_dfc(
    seq_len(nrow(temp_weights)), 
    ~ {
      row <- .
      rlang::eval_tidy(expr(case_when(
        !!! imap(
          windows, 
          ~ expr(
            between(temp, !! .x[1], !! .x[2]) ~ !! temp_weights[[.y]][row]
          )
        )
      )))
    }
  ) %>% 
  set_names(paste0("temp_wt_", seq_along(.)))
toc()
#> 65.18 sec elapsed

all.equal(res, test_tibble2)
#> [1] TRUE

解决方案 #2

tic()
res2 <- 
  rlang::eval_tidy(expr(case_when(
    !!! imap(
      windows, 
      ~ expr(
        between(temp, !! .x[1], !! .x[2]) ~ list(!! temp_weights[[.y]])
      )
    )
  ))) %>% 
  do.call(what = rbind) %>% 
  as_tibble() %>% 
  set_names(paste0("temp_wt_", seq_along(.)))
#> Warning: `as_tibble.matrix()` requires a matrix with column names or a `.name_repair` argument. Using compatibility `.name_repair`.
#> This warning is displayed once per session.
toc()
#> 2.76 sec elapsed

all.equal(res2, test_tibble2)
#> [1] TRUE

解决方案 #3 使用 data.table

tic()
res3 <-
  rlang::eval_tidy(expr(case_when(
    !!! imap(
      windows,
      ~ expr(
        between(temp, !! .x[1], !! .x[2]) ~ list(!! temp_weights[[.y]])
      )
    )
  ))) %>%
  data.table::transpose(., fill = NA) %>%
  set_names(paste0("temp_wt_", seq_along(.))) %>%
  as_tibble()
toc()
#> 4.69 sec elapsed

all.equal(res3, test_tibble2)
#> [1] TRUE

总之,解决方案 #2 似乎是最快的(2.76 秒),其次是 data.table 解决方案(4.69 秒)。但是,我很欣赏 data.table 解决方案具有 fill 选项来保留 NA。

reprex package (v0.3.0)

于 2019-08-02 创建