如何使用变量作为 R 中 dplyr::slice_max() 函数的参数

how to use a variable as a parameter of the dplyr::slice_max() function in R

给定一个 data.frame:

tibble(group = c(rep("A", 4), rep("B", 4), rep("C", 4)),
        value = runif(12),
        n_slice = c(rep(2, 4), rep(1, 4), rep(3, 4)) )

# A tibble: 12 x 3
   group  value n_slice
   <chr>  <dbl>   <dbl>
 1 A     0.853        2
 2 A     0.726        2
 3 A     0.783        2
 4 A     0.0426       2
 5 B     0.320        1
 6 B     0.683        1
 7 B     0.732        1
 8 B     0.319        1
 9 C     0.118        3
10 C     0.0259       3
11 C     0.818        3
12 C     0.635        3

我想按组切片,每组中的行数不同

我尝试了下面的代码,但我收到通知说“n”必须是一个常量

re %>% 
   group_by(group) %>% 
   slice_max(value, n = n_slice)

Erro: `n` must be a constant in `slice_max()`.

预期输出:

  group value n_slice
  <chr> <dbl>   <dbl>
1 A     0.853       2
2 A     0.783       2
3 B     0.732       1
4 C     0.818       3
5 C     0.635       3
6 C     0.118       3

在这种情况下,一个选项是group_modify

library(dplyr)
re %>% 
   group_by(group) %>% 
   group_modify(~ .x %>%
          slice_max(value, n = first(.x$n_slice))) %>%
   ungroup

-输出

# A tibble: 6 × 3
  group value n_slice
  <chr> <dbl>   <dbl>
1 A     0.931       2
2 A     0.931       2
3 B     0.722       1
4 C     0.591       3
5 C     0.519       3
6 C     0.494       3

或者另一种选择是 summarise 使用 cur_data() 然后 unnest

library(tidyr)
re %>%
    group_by(group) %>%
    summarise(out = list(cur_data() %>% 
        slice_max(value, n = first(n_slice)))) %>% 
    unnest(out)

-输出

# A tibble: 6 × 3
  group value n_slice
  <chr> <dbl>   <dbl>
1 A     0.931       2
2 A     0.931       2
3 B     0.722       1
4 C     0.591       3
5 C     0.519       3
6 C     0.494       3

我认为 slice_max 不支持这一点,也许是因为不难想象 n_slice 在组内不是恒定的数据(该操作是模棱两可的)。尝试使用 filter:

set.seed(42)
re <- tibble(group = c(rep("A", 4), rep("B", 4), rep("C", 4)),
        value = runif(12),
        n_slice = c(rep(2, 4), rep(1, 4), rep(3, 4)) )
re
# # A tibble: 12 x 3
#    group value n_slice
#    <chr> <dbl>   <dbl>
#  1 A     0.915       2
#  2 A     0.937       2
#  3 A     0.286       2
#  4 A     0.830       2
#  5 B     0.642       1
#  6 B     0.519       1
#  7 B     0.737       1
#  8 B     0.135       1
#  9 C     0.657       3
# 10 C     0.705       3
# 11 C     0.458       3
# 12 C     0.719       3

re %>%
  group_by(group) %>%
  filter(rank(-value) <= n_slice[1])
# # A tibble: 6 x 3
# # Groups:   group [3]
#   group value n_slice
#   <chr> <dbl>   <dbl>
# 1 A     0.915       2
# 2 A     0.937       2
# 3 B     0.737       1
# 4 C     0.657       3
# 5 C     0.705       3
# 6 C     0.719       3

备注:

  • 由于数据中可能存在 关系 ,使用 rank(., ties.method = ...)(参见 ?rank)或dplyr::dense_rank.

  • 如果您要切片的列不支持求反(例如,DatePOSIXt),那么您可以将 rank(-value) 更改为 n() - rank(value) + 1L <= n_slice[1] 以获得相同的效果(或更简单地说 n() - rank(value) < n_slice[1])。另一个选项是 rank(desc(value)),感谢@IceCreamToucan 的建议。


如果性能是个问题(你有更多的行),那么每组只使用 filter 是目前最快的答案:-)

bench::mark(
  akrun1 = re %>% group_by(group) %>% group_modify(~ .x %>% slice_max(value, n = first(.x$n_slice))) %>% ungroup,
  akrun2 = re %>% group_by(group) %>% summarise(out = list(cur_data() %>% slice_max(value, n = first(n_slice)))) %>% tidyr::unnest(out),
  r2evans = re %>% group_by(group) %>% filter(rank(-value) <= n_slice[1]) %>% ungroup() %>% arrange(group, -value)
)
# # A tibble: 3 x 13
#   expression     min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result     memory       time     gc       
#   <bch:expr> <bch:t> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>     <list>       <list>   <list>   
# 1 akrun1       7.2ms   9.05ms     108.     7.55KB     13.8    47     6      436ms <tibble [~ <Rprofmem[,~ <bch:tm~ <tibble ~
# 2 akrun2      9.56ms  11.32ms      87.0      20KB     12.1    36     5      414ms <tibble [~ <Rprofmem[,~ <bch:tm~ <tibble ~
# 3 r2evans     4.21ms   5.27ms     186.     4.87KB     14.3    78     6      420ms <tibble [~ <Rprofmem[,~ <bch:tm~ <tibble ~

(在我的末尾添加 arrange(.) 以准确模拟输出。)

如果您没有更多的行,或者即使有,坦率地说,可读性更为重要。所有的答案都会产生相同的结果,所以无论哪个对用户来说更有意义(以及 future self 当你回顾 6 个月的旧代码时),运行时的一点惩罚通常是值得的.

并不是说这个问题需要另一种方法来复制 slice_max,只是为了好玩,您可以使用 arrange 后接 slice

library(dplyr, warn.conflicts = F)
set.seed(42)
re <- tibble(group = c(rep("A", 4), rep("B", 4), rep("C", 4)),
        value = runif(12),
        n_slice = c(rep(2, 4), rep(1, 4), rep(3, 4)) )

re %>% 
  group_by(group) %>% 
  arrange(desc(value)) %>% 
  slice(seq(first(n_slice))) %>% 
  ungroup
#> # A tibble: 6 × 3
#>   group value n_slice
#>   <chr> <dbl>   <dbl>
#> 1 A     0.937       2
#> 2 A     0.915       2
#> 3 B     0.737       1
#> 4 C     0.719       3
#> 5 C     0.705       3
#> 6 C     0.657       3

reprex package (v2.0.1)

于 2021-12-17 创建

令人惊讶的是,这似乎快了一点

library(bench) 
library(dplyr, warn.conflicts = F)

set.seed(42)
n <- 1e5
re <- tibble(group = c(rep("A", n), rep("B", n), rep("C", n)),
        value = runif(n*3),
        n_slice = c(rep(sample(n, 1), n), rep(sample(n, 1), n), rep(sample(n, 1), n)) )


bench::mark(
  akrun1 = re %>% group_by(group) %>% group_modify(~ .x %>% slice_max(value, n = first(.x$n_slice))) %>% ungroup,
  akrun2 = re %>% group_by(group) %>% summarise(out = list(cur_data() %>% slice_max(value, n = first(n_slice)))) %>% tidyr::unnest(out),
  r2evans = re %>% group_by(group) %>% filter(rank(-value) <= n_slice[1]) %>% ungroup() %>% arrange(group, -value),
  arrange = re %>% group_by(group) %>% arrange(desc(value)) %>% slice(seq(first(n_slice))) %>% ungroup %>% arrange(group, -value)
)
#> Warning: Some expressions had a GC in every iteration; so filtering is disabled.
#> # A tibble: 4 × 6
#>   expression      min   median `itr/sec` mem_alloc `gc/sec`
#>   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
#> 1 akrun1      167.7ms  169.9ms      5.71      31MB     5.71
#> 2 akrun2      166.2ms  172.4ms      5.82    29.5MB     9.70
#> 3 r2evans       173ms  175.2ms      5.67    31.8MB     3.78
#> 4 arrange      66.7ms   75.2ms     11.9     29.5MB    17.9

reprex package (v2.0.1)

于 2021-12-17 创建