Oversample within group
I want to oversample so that the binary dependent variable is balanced within each group of my dataset.
My data looks like this:
library(dplyr)
library(purrr)
library(tidyr)
set.seed(123)
# example data
(data <- tibble(
  country = c("France", "France", "France",
              "UK", "UK", "UK", "UK", "UK", "UK"),
  YES = c(0, 0, 1,
          0, 0, 0, 0, 1, 1),
  X = rnorm(9, 0, 1)
))
# A tibble: 9 x 3
country YES X
<chr> <dbl> <dbl>
1 France 0 -1.12
2 France 0 -0.200
3 France 1 0.781
4 UK 0 0.100
5 UK 0 0.0997
6 UK 0 -0.380
7 UK 0 -0.0160
8 UK 1 -0.0265
9 UK 1 0.860
I am trying to balance YES within France and within the UK by oversampling. In France I want 4 observations and in the UK 8, so that one random sample looks like this:
# A tibble: 12 x 3
country YES X
<chr> <dbl> <dbl>
1 France 0 -1.12
2 France 0 -0.200
3 France 1 0.781
3 France 1 0.781
4 UK 0 0.100
5 UK 0 0.0997
6 UK 0 -0.380
7 UK 0 -0.0160
8 UK 1 -0.0265
9 UK 1 0.860
8 UK 1 -0.0265
8 UK 1 -0.0265
My approach looks like this:
# oversample 1's within each country
(n_data <- data %>%
   group_by(country) %>%
   nest(.key = "original") %>%
   mutate(os = map(original, ~ group_by(., YES))) %>%
   mutate(os = map(os, ~ slice_sample(., replace = TRUE, prop = 1))))
# A tibble: 2 x 3
# Groups: country [2]
country original os
<chr> <list> <list>
1 France <tibble [3 x 2]> <tibble [3 x 2]>
2 UK <tibble [6 x 2]> <tibble [6 x 2]>
Warning message:
`.key` is deprecated
So the tibbles in os should be 4 x 2 and 8 x 2. Does anyone know how to do this?
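A minimal sketch of one way the nested attempt could reach the 4 x 2 and 8 x 2 sizes. `balance_yes()` is a hypothetical helper (not from the original post) that splits each country's tibble by YES and tops up the smaller class to the size of the larger one by sampling rows with replacement. It also swaps the deprecated `.key` argument for tidyr's `nest(original = c(YES, X))` form.

```r
library(dplyr)
library(purrr)
library(tidyr)

set.seed(123)
data <- tibble(
  country = c("France", "France", "France",
              "UK", "UK", "UK", "UK", "UK", "UK"),
  YES = c(0, 0, 1,
          0, 0, 0, 0, 1, 1),
  X = rnorm(9, 0, 1)
)

# Hypothetical helper: top up every YES class in one country's tibble to the
# size of its largest class by appending rows drawn with replacement
balance_yes <- function(df) {
  goal <- max(table(df$YES))
  pieces <- lapply(split(df, df$YES), function(g) {
    extra <- goal - nrow(g)
    # sample(nrow(g), 0, ...) returns integer(0), so a class that is already
    # at the goal size gets no extra rows
    bind_rows(g, g[sample(nrow(g), extra, replace = TRUE), ])
  })
  bind_rows(pieces)
}

# nest(original = c(YES, X)) nests by the remaining column (country)
(n_data <- data %>%
   nest(original = c(YES, X)) %>%
   mutate(os = map(original, balance_yes)))

# Flatten back to one row per observation
balanced <- n_data %>%
  select(country, os) %>%
  unnest(os)
```

Here the os column holds a 4 x 2 tibble for France and an 8 x 2 tibble for the UK; `unnest()` turns them back into a single 12-row tibble.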
This may look overly complicated, but each individual step is clear and robust:
data %>%
  count(country, YES) %>%
  group_by(country) %>%
  ## Figure out how many additional rows are needed
  mutate(
    goal_rows = max(n),
    extra_rows = goal_rows - n
  ) %>%
  select(country, YES, extra_rows) %>%
  ## Keep only the country/YES combinations that need extra rows
  filter(extra_rows > 0) %>%
  ## Join back to original data
  left_join(data, by = c("country", "YES")) %>%
  group_by(country) %>%
  ## Randomly keep the appropriate number of rows
  ## (sampled without replacement, so this assumes extra_rows is no larger
  ## than the number of rows in the smaller class)
  mutate(rand = rank(runif(n()))) %>%
  filter(rand <= extra_rows) %>%
  select(-extra_rows, -rand) %>%
  ## Combine oversampled rows with original data
  bind_rows(data) %>%
  arrange(country, YES)
# # A tibble: 12 x 3
# # Groups: country [2]
# country YES X
# <chr> <dbl> <dbl>
# 1 France 0 1.88
# 2 France 0 -0.0793
# 3 France 1 0.812
# 4 France 1 0.812
# 5 UK 0 -1.66
# 6 UK 0 -0.797
# 7 UK 0 0.639
# 8 UK 0 -0.141
# 9 UK 1 -0.207
# 10 UK 1 1.30
# 11 UK 1 -0.207
# 12 UK 1 1.30
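The rank(runif(n())) step keeps each minority row at most once, so it can only add as many rows as the smaller class already has (enough here, but not if, say, a country had 1 YES and 5 NO). A minimal sketch of a variant that draws the extra rows with replacement instead, using dplyr's `slice()` per country/YES group; the `goal` column is an illustrative name, not part of the original answer.

```r
library(dplyr)

set.seed(123)
data <- tibble(
  country = c("France", "France", "France",
              "UK", "UK", "UK", "UK", "UK", "UK"),
  YES = c(0, 0, 1,
          0, 0, 0, 0, 1, 1),
  X = rnorm(9, 0, 1)
)

balanced <- data %>%
  group_by(country) %>%
  # Target size for each YES class: the size of the largest class in that country
  mutate(goal = max(table(YES))) %>%
  group_by(country, YES) %>%
  # Keep every original row, then append (goal - n()) row indices drawn with
  # replacement; for the larger class this draws zero extra rows
  slice(c(seq_len(n()), sample(n(), first(goal) - n(), replace = TRUE))) %>%
  ungroup() %>%
  select(-goal) %>%
  arrange(country, YES)

balanced
```

Because every original row is kept and only the shortfall is sampled, each country ends up with equal YES = 0 and YES = 1 counts (2 + 2 for France, 4 + 4 for the UK) regardless of how small the minority class is.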