purrr:按 (nest) 和 bootstrap 分组
purrr: Group by (nest) and bootstrap
我计算了 mtcars
数据集中 mpg
变量的 bootstrap 个样本的平均值。我的代码看起来像这样(请让我知道是否有 "better practice" 可以做到这一点。):
mean_mpg <- function(x) {
rsample::analysis(x) %>%
pull(mpg) %>%
mean()
}
mtcars2 <- rsample::bootstraps(mtcars) %>%
mutate(mean_mpg = purrr::map(splits, mean_mpg)) %>%
tidyr::unnest(mean_mpg) %>%
select(-splits)
但是,现在我想对分组数据集执行相同的操作。例如:
mtcars %>%
group_by(am)
# now calculate boostrap means of `mpg` for each `am` group
最好的方法是什么?
我想我会 nest()
这样做,而不是 group_by()
。
这里是一个略微修改的版本,说明如何为整个数据集的每个 bootstrap 重采样找到平均值 mpg
。
library(rsample)
library(tidyverse)
bootstraps(mtcars) %>%
mutate(mpg = map(splits, ~ analysis(.) %>% pull(mpg)),
mean_mpg = map_dbl(mpg, mean))
#> # Bootstrap sampling
#> # A tibble: 25 x 4
#> splits id mpg mean_mpg
#> * <list> <chr> <list> <dbl>
#> 1 <split [32/10]> Bootstrap01 <dbl [32]> 18.8
#> 2 <split [32/13]> Bootstrap02 <dbl [32]> 20.4
#> 3 <split [32/9]> Bootstrap03 <dbl [32]> 21.1
#> 4 <split [32/12]> Bootstrap04 <dbl [32]> 19.4
#> 5 <split [32/10]> Bootstrap05 <dbl [32]> 19.8
#> 6 <split [32/11]> Bootstrap06 <dbl [32]> 20.1
#> 7 <split [32/13]> Bootstrap07 <dbl [32]> 19.1
#> 8 <split [32/11]> Bootstrap08 <dbl [32]> 18.7
#> 9 <split [32/13]> Bootstrap09 <dbl [32]> 19.3
#> 10 <split [32/13]> Bootstrap10 <dbl [32]> 20.9
#> # … with 15 more rows
下面是我将如何为 am
的每个值创建 bootstrap 重采样,然后为这些重采样找到 mpg
的平均值。
mtcars %>%
nest(-am) %>%
mutate(nested_boot = map(data, bootstraps)) %>%
select(-data) %>%
unnest(nested_boot) %>%
mutate(mpg = map(splits, ~ analysis(.) %>% pull(mpg)),
mean_mpg = map_dbl(mpg, mean))
#> # A tibble: 50 x 5
#> am splits id mpg mean_mpg
#> <dbl> <list> <chr> <list> <dbl>
#> 1 1 <split [13/4]> Bootstrap01 <dbl [13]> 21.9
#> 2 1 <split [13/4]> Bootstrap02 <dbl [13]> 24.0
#> 3 1 <split [13/5]> Bootstrap03 <dbl [13]> 24.8
#> 4 1 <split [13/5]> Bootstrap04 <dbl [13]> 25.9
#> 5 1 <split [13/3]> Bootstrap05 <dbl [13]> 24.0
#> 6 1 <split [13/5]> Bootstrap06 <dbl [13]> 22.1
#> 7 1 <split [13/4]> Bootstrap07 <dbl [13]> 24.3
#> 8 1 <split [13/4]> Bootstrap08 <dbl [13]> 25.0
#> 9 1 <split [13/5]> Bootstrap09 <dbl [13]> 22.7
#> 10 1 <split [13/6]> Bootstrap10 <dbl [13]> 23.3
#> # … with 40 more rows
由 reprex package (v0.3.0)
于 2020-05-26 创建
我计算了 mtcars
数据集中 mpg
变量的 bootstrap 个样本的平均值。我的代码看起来像这样(请让我知道是否有 "better practice" 可以做到这一点。):
mean_mpg <- function(x) {
rsample::analysis(x) %>%
pull(mpg) %>%
mean()
}
mtcars2 <- rsample::bootstraps(mtcars) %>%
mutate(mean_mpg = purrr::map(splits, mean_mpg)) %>%
tidyr::unnest(mean_mpg) %>%
select(-splits)
但是,现在我想对分组数据集执行相同的操作。例如:
mtcars %>%
group_by(am)
# now calculate boostrap means of `mpg` for each `am` group
最好的方法是什么?
我想我会 nest()
这样做,而不是 group_by()
。
这里是一个略微修改的版本,说明如何为整个数据集的每个 bootstrap 重采样找到平均值 mpg
。
library(rsample)
library(tidyverse)
bootstraps(mtcars) %>%
mutate(mpg = map(splits, ~ analysis(.) %>% pull(mpg)),
mean_mpg = map_dbl(mpg, mean))
#> # Bootstrap sampling
#> # A tibble: 25 x 4
#> splits id mpg mean_mpg
#> * <list> <chr> <list> <dbl>
#> 1 <split [32/10]> Bootstrap01 <dbl [32]> 18.8
#> 2 <split [32/13]> Bootstrap02 <dbl [32]> 20.4
#> 3 <split [32/9]> Bootstrap03 <dbl [32]> 21.1
#> 4 <split [32/12]> Bootstrap04 <dbl [32]> 19.4
#> 5 <split [32/10]> Bootstrap05 <dbl [32]> 19.8
#> 6 <split [32/11]> Bootstrap06 <dbl [32]> 20.1
#> 7 <split [32/13]> Bootstrap07 <dbl [32]> 19.1
#> 8 <split [32/11]> Bootstrap08 <dbl [32]> 18.7
#> 9 <split [32/13]> Bootstrap09 <dbl [32]> 19.3
#> 10 <split [32/13]> Bootstrap10 <dbl [32]> 20.9
#> # … with 15 more rows
下面是我将如何为 am
的每个值创建 bootstrap 重采样,然后为这些重采样找到 mpg
的平均值。
mtcars %>%
nest(-am) %>%
mutate(nested_boot = map(data, bootstraps)) %>%
select(-data) %>%
unnest(nested_boot) %>%
mutate(mpg = map(splits, ~ analysis(.) %>% pull(mpg)),
mean_mpg = map_dbl(mpg, mean))
#> # A tibble: 50 x 5
#> am splits id mpg mean_mpg
#> <dbl> <list> <chr> <list> <dbl>
#> 1 1 <split [13/4]> Bootstrap01 <dbl [13]> 21.9
#> 2 1 <split [13/4]> Bootstrap02 <dbl [13]> 24.0
#> 3 1 <split [13/5]> Bootstrap03 <dbl [13]> 24.8
#> 4 1 <split [13/5]> Bootstrap04 <dbl [13]> 25.9
#> 5 1 <split [13/3]> Bootstrap05 <dbl [13]> 24.0
#> 6 1 <split [13/5]> Bootstrap06 <dbl [13]> 22.1
#> 7 1 <split [13/4]> Bootstrap07 <dbl [13]> 24.3
#> 8 1 <split [13/4]> Bootstrap08 <dbl [13]> 25.0
#> 9 1 <split [13/5]> Bootstrap09 <dbl [13]> 22.7
#> 10 1 <split [13/6]> Bootstrap10 <dbl [13]> 23.3
#> # … with 40 more rows
由 reprex package (v0.3.0)
于 2020-05-26 创建