sparklyr:如何跨组获取平衡样本
sparklyr: how to take a balanced sample across groups
我想从 sparklyr
.
中我的 Spark DataFrame 的每个 classes 中采样 n 行
我知道 dplyr::sample_n
函数不能用于此 () 所以我使用了 sparklyr::sdf_sample()
函数。这样做的问题是我无法按组进行抽样,即从每个 class 中获得 10 个观察值,我只能指定要抽样的整个数据集的分数。
我有一个解决方法,可以在循环中对每个组单独使用 sdf_sample()
,但是由于该函数没有 return 确切的样本大小,所以这仍然不理想。
解决方法的 R 代码:
library(sparklyr)
library(dplyr)
sc <- spark_connect(master = "local", version = "2.3")
# copy iris to our spark cluster
iris_tbl <- copy_to(sc, iris, overwrite = TRUE)
# get class counts
class_counts <- iris_tbl %>% count(Species) %>%
collect()
# Species n
# <chr> <dbl>
#1 versicolor 50
#2 virginica 50
#3 setosa 50
# we want to sample n = 10 points from each class
n <- 10
sampled_iris <- data.frame(stringsAsFactors = F)
for( i in seq_along(class_counts$Species)){
my_frac <- n / class_counts[[i, 'n']]
my_class <- class_counts[[i, 'Species']]
tmp <- iris_tbl %>%
filter(Species == my_class) %>%
sdf_sample(fraction = my_frac) %>%
collect()
sampled_iris <- bind_rows(sampled_iris, tmp)
}
我们没有从每个 class:
中得到恰好 10 个样本
# new counts
sampled_iris %>% count(Species)
#Species n
# <chr> <int>
#1 setosa 7
#2 versicolor 9
#3 virginica 6
我想知道是否有更好的方法使用 sparklyr 获得跨组的平衡样本?或者甚至使用 sql 查询,我可以使用 DBI::dbGetQuery()
?
直接将其传递给集群
I can't sample by group
只要分组列是字符串(这是 sparklyr
类型映射的限制),那部分就可以使用 DataFrameStatFunctions.sampleBy
:
轻松处理
spark_dataframe(iris_tbl) %>%
sparklyr::invoke("stat") %>%
sparklyr::invoke(
"sampleBy",
"Species",
fractions=as.environment(list(
"setosa"=0.2,
"versicolor"=0.2,
"virginica"=0.2
)),
seed=1L
) %>% sparklyr::sdf_register()
然而,没有分布式和可扩展的方法会给你 "exact sample size"。可以使用 hack,例如:
iris_tbl %>%
group_by(Species) %>%
mutate(rand = rand()) %>%
arrange(rand, .by_group=TRUE) %>%
filter(row_number() <= 10) %>%
select(-rand)
但是这种依赖于 window 函数的方法对倾斜的数据分布高度敏感,并且通常不能很好地扩展。
如果样本很小,你可以更进一步,但是先过采样(使用第一种方法)然后取精确样本(使用第二种方法),但如果你的数据大到足以用 Spark 处理, 小的波动应该不重要。
我想从 sparklyr
.
我知道 dplyr::sample_n
函数不能用于此 (sparklyr::sdf_sample()
函数。这样做的问题是我无法按组进行抽样,即从每个 class 中获得 10 个观察值,我只能指定要抽样的整个数据集的分数。
我有一个解决方法,可以在循环中对每个组单独使用 sdf_sample()
,但是由于该函数没有 return 确切的样本大小,所以这仍然不理想。
解决方法的 R 代码:
library(sparklyr)
library(dplyr)
sc <- spark_connect(master = "local", version = "2.3")
# copy iris to our spark cluster
iris_tbl <- copy_to(sc, iris, overwrite = TRUE)
# get class counts
class_counts <- iris_tbl %>% count(Species) %>%
collect()
# Species n
# <chr> <dbl>
#1 versicolor 50
#2 virginica 50
#3 setosa 50
# we want to sample n = 10 points from each class
n <- 10
sampled_iris <- data.frame(stringsAsFactors = F)
for( i in seq_along(class_counts$Species)){
my_frac <- n / class_counts[[i, 'n']]
my_class <- class_counts[[i, 'Species']]
tmp <- iris_tbl %>%
filter(Species == my_class) %>%
sdf_sample(fraction = my_frac) %>%
collect()
sampled_iris <- bind_rows(sampled_iris, tmp)
}
我们没有从每个 class:
中得到恰好 10 个样本# new counts
sampled_iris %>% count(Species)
#Species n
# <chr> <int>
#1 setosa 7
#2 versicolor 9
#3 virginica 6
我想知道是否有更好的方法使用 sparklyr 获得跨组的平衡样本?或者甚至使用 sql 查询,我可以使用 DBI::dbGetQuery()
?
I can't sample by group
只要分组列是字符串(这是 sparklyr
类型映射的限制),那部分就可以使用 DataFrameStatFunctions.sampleBy
:
spark_dataframe(iris_tbl) %>%
sparklyr::invoke("stat") %>%
sparklyr::invoke(
"sampleBy",
"Species",
fractions=as.environment(list(
"setosa"=0.2,
"versicolor"=0.2,
"virginica"=0.2
)),
seed=1L
) %>% sparklyr::sdf_register()
然而,没有分布式和可扩展的方法会给你 "exact sample size"。可以使用 hack,例如:
iris_tbl %>%
group_by(Species) %>%
mutate(rand = rand()) %>%
arrange(rand, .by_group=TRUE) %>%
filter(row_number() <= 10) %>%
select(-rand)
但是这种依赖于 window 函数的方法对倾斜的数据分布高度敏感,并且通常不能很好地扩展。
如果样本很小,你可以更进一步,但是先过采样(使用第一种方法)然后取精确样本(使用第二种方法),但如果你的数据大到足以用 Spark 处理, 小的波动应该不重要。