Aggregating the standard deviation and counting non-NAs in sparklyr

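I have a large data.frame, and I have been using summarise and across to compute summary statistics for many variables. Because of the size of the data.frame, I have had to start processing the data in sparklyr.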

Since sparklyr does not support across, I am now using summarise_each instead. This works fine, except that summarise_each in sparklyr does not seem to support sd or sum(!is.na(.)).

Below is an example dataset and how I would normally summarise it with dplyr:

library(dplyr)

test <- data.frame(ID = c("Group1","Group1",'Group1','Group1','Group1','Group1','Group1',
                          "Group2","Group2","Group2",'Group2','Group2','Group2',"Group2",
                          "Group3","Group3","Group3"),
                      Value1 = c(-100,-10,-5,-5,-5,1,2,1,2,3,4,4,4,4,1,2,3),
                      Value2 = c(50,100,10,-5,3,1,2,2,2,3,4,4,4,4,1,2,3))

# No grouping here, so the whole table is summarised into a single row
test %>% 
  summarise(across(Value1:Value2, ~sum(!is.na(.)), .names = "{col}_count"),
            across(Value1:Value2, ~min(., na.rm = TRUE), .names = "{col}_min"),
            across(Value1:Value2, ~max(., na.rm = TRUE), .names = "{col}_max"),
            across(Value1:Value2, ~mean(., na.rm = TRUE), .names = "{col}_mean"),
            across(Value1:Value2, ~sd(., na.rm = TRUE), .names = "{col}_sd"))

# A tibble: 1 x 10
  Value1_count Value2_count Value1_min Value2_min Value1_max Value2_max Value1_mean Value2_mean Value1_sd Value2_sd
         <int>        <int>      <dbl>      <dbl>      <dbl>      <dbl>       <dbl>       <dbl>     <dbl>     <dbl>
1           17           17       -100         -5          4        100       -5.53        11.2      24.7      25.8

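I have also been able to get the corresponding per-group results (grouped by ID) with summarise_each, as follows (the printed output is truncated and does not show the two sd columns):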

test %>% 
  group_by(ID) %>%
  summarise_each(funs(min = min(., na.rm = TRUE),
                      max = max(., na.rm = TRUE),
                      mean = mean(., na.rm = TRUE), 
                      sum = sum(., na.rm = TRUE),
                      sd = sd(., na.rm = TRUE)))

  ID     Value1_min Value2_min Value1_max Value2_max Value1_mean Value2_mean Value1_sum Value2_sum
  <fct>       <dbl>      <dbl>      <dbl>      <dbl>       <dbl>       <dbl>      <dbl>      <dbl>
1 Group1       -100         -5          2        100      -17.4        23          -122        161
2 Group2          1          2          4          4        3.14        3.29         22         23
3 Group3          1          1          3          3        2           2             6          6
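As an aside, funs() and summarise_each() are deprecated in current dplyr. A minimal local sketch of the same per-group summary using summarise_all() with a named list of formulas, with the non-NA count added, might look like this. It is only run on an ordinary data.frame here; whether each lambda translates unchanged on a Spark table is an assumption not checked in this sketch.

```r
library(dplyr)

# Same sample data as above
test <- data.frame(ID = c(rep("Group1", 7), rep("Group2", 7), rep("Group3", 3)),
                   Value1 = c(-100, -10, -5, -5, -5, 1, 2, 1, 2, 3, 4, 4, 4, 4, 1, 2, 3),
                   Value2 = c(50, 100, 10, -5, 3, 1, 2, 2, 2, 3, 4, 4, 4, 4, 1, 2, 3))

# A named list of purrr-style formulas replaces funs(); the list names
# become column suffixes such as Value1_count or Value2_sd.
stats <- test %>%
  group_by(ID) %>%
  summarise_all(list(count = ~sum(!is.na(.)),
                     min   = ~min(., na.rm = TRUE),
                     max   = ~max(., na.rm = TRUE),
                     mean  = ~mean(., na.rm = TRUE),
                     sum   = ~sum(., na.rm = TRUE),
                     sd    = ~sd(., na.rm = TRUE)))

stats
```

The list form keeps one named formula per statistic, so adding or dropping a statistic is a one-line change.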

With sparklyr, I have managed to compute min, max, mean, and sum, as follows:

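library(sparklyr)
library(dplyr)

sc <- spark_connect(master = "local", version = "2.4.3")
# Use forward slashes in Windows paths; "\p" is not a valid escape in an R string
test <- spark_read_csv(sc = sc, path = "C:/path/test space.csv")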

test %>% 
  group_by(ID) %>%
  summarise_each(funs(min = min(., na.rm = TRUE),
                      max = max(., na.rm = TRUE),
                      mean = mean(., na.rm = TRUE), 
                      sum = sum(., na.rm = TRUE)))
# Source: spark<?> [?? x 9]
  ID     Value1_min Value_2_min Value1_max Value_2_max Value1_mean Value_2_mean Value1_sum Value_2_sum
  <chr>       <int>       <int>      <int>       <int>       <dbl>        <dbl>      <dbl>       <dbl>
1 Group2          1           2          4           4        3.14         3.29         22          23
2 Group3          1           1          3           3        2            2             6           6
3 Group1       -100          -5          2         100      -17.4         23          -122         161

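However, I get an error when I try to compute sd and sum(!is.na(.)). Below are the code and the error message I receive. Is there any workaround that would let me summarise these values?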

test %>% 
  group_by(ID) %>%
  summarise_each(funs(min = min(., na.rm = TRUE),
                      max = max(., na.rm = TRUE),
                      mean = mean(., na.rm = TRUE), 
                      sum = sum(., na.rm = TRUE),
                      sd = sd(., na.rm = TRUE)))

Error: org.apache.spark.sql.catalyst.parser.ParseException: 
mismatched input 'AS' expecting ')'(line 1, pos 298)

== SQL ==
SELECT `ID`, MIN(`Value1`) AS `Value1_min`, MIN(`Value_2`) AS `Value_2_min`, MAX(`Value1`) AS `Value1_max`, MAX(`Value_2`) AS `Value_2_max`, AVG(`Value1`) AS `Value1_mean`, AVG(`Value_2`) AS `Value_2_mean`, SUM(`Value1`) AS `Value1_sum`, SUM(`Value_2`) AS `Value_2_sum`, stddev_samp(`Value1`, TRUE AS `na.rm`) AS `Value1_sd`, stddev_samp(`Value_2`, TRUE AS `na.rm`) AS `Value_2_sd`
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------^^^
FROM `test_space_30172a44_c0aa_4305_9a5e_d45fa77ba0b9`
GROUP BY `ID`

    at org.apache.spark.sql.catalyst.parser.ParseException.withCommand(ParseDriver.scala:241)
    at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parse(ParseDriver.scala:117)
    at org.apache.spark.sql.execution.SparkSqlParser.parse(SparkSqlParser.scala:48)
    at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parsePlan(ParseDriver.scala:69)
    at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:642)
    at sun.reflect.GeneratedMethodAccessor66.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at sparklyr.Invoke.invoke(invoke.scala:147)
    at sparklyr.StreamHandler.handleMethodCall(stream.scala:136)
    at sparklyr.StreamHandler.read(stream.scala:61)
    at sparklyr.BackendHandler$$anonfun$channelRead0.apply$mcV$sp(handler.scala:58)
    at scala.util.control.Breaks.breakable(Breaks.scala:38)
    at sparklyr.BackendHandler.channelRead0(handler.scala:38)
    at sparklyr.BackendHandler.channelRead0(handler.scala:14)
    at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:105)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:362)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:348)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:340)
    at io.netty.handler.codec.MessageToMessageDecoder.channelRead(MessageToMessageDecoder.java:102)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:362)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:348)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:340)
    at io.netty.handler.codec.ByteToMessageDecoder.fireChannelRead(ByteToMessageDecoder.java:310)
    at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:284)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:362)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:348)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:340)
    at io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1359)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:362)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:348)
    at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:935)
    at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:138)
    at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:645)
    at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:580)
    at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:497)
    at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:459)
    at io.netty.util.concurrent.SingleThreadEventExecutor.run(SingleThreadEventExecutor.java:858)
    at io.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:138)
    at java.lang.Thread.run(Thread.java:748)
In addition: Warning messages:
1: Named arguments ignored for SQL stddev_samp 
2: Named arguments ignored for SQL stddev_samp 

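The problem is the na.rm argument. Spark's stddev_samp function has no such argument, and the SQL translation used by sparklyr passes it straight through as TRUE AS `na.rm` inside stddev_samp(...), which Spark cannot parse (hence the ParseException above).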

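SQL aggregate functions always skip missing (NULL) values, so you do not need to specify na.rm. Here test_spark is the Spark table (called test in the question):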

test_spark %>% 
  group_by(ID) %>%
  summarise_each(funs(min = min(.),
                      max = max(.),
                      mean = mean(.), 
                      sum = sum(.),
                      sd = sd(.)))
#> # Source: spark<?> [?? x 11]
#>   ID     Value1_min Value2_min Value1_max Value2_max Value1_mean Value2_mean
#>   <chr>       <dbl>      <dbl>      <dbl>      <dbl>       <dbl>       <dbl>
#> 1 Group2          1          2          4          4        3.14        3.29
#> 2 Group1       -100         -5          2        100      -17.4        23   
#> 3 Group3          1          1          3          3        2           2   
#>   Value1_sum Value2_sum Value1_sd Value2_sd
#>        <dbl>      <dbl>     <dbl>     <dbl>
#> 1         22         23      1.21     0.951
#> 2       -122        161     36.6     38.6  
#> 3          6          6      1        1  
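Why dropping na.rm is safe can be checked in plain R. R's own aggregates return NA as soon as one value is missing, whereas SQL's AVG, SUM, MIN, MAX and similar aggregates skip NULLs, which matches na.rm = TRUE. A small worked example on a made-up vector x:

```r
# Made-up vector with one missing value
x <- c(1, NA, 3)

mean(x)                # NA: R propagates missing values by default
mean(x, na.rm = TRUE)  # 2: same as SQL AVG over 1, NULL, 3
sum(x, na.rm = TRUE)   # 4: same as SQL SUM over 1, NULL, 3
sum(!is.na(x))         # 2: number of non-missing values
```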

The failure with na.rm looks like a bug specific to summarise, because sd with na.rm works fine in mutate:

test_spark %>% 
  group_by(ID) %>%
  mutate_each(funs(sd = sd(., na.rm = TRUE))) 

For sum(!is.na(.)), write it as sum(ifelse(is.na(.), 0, 1)) instead, so that numbers rather than logical values are summed.
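A minimal sketch of that workaround on a small local data.frame (the data is made up for illustration). ifelse() and is.na() are among the functions the dplyr SQL translation maps to CASE WHEN and IS NULL, so the same summarise() call should also work on a Spark table, although that is not run here. A plausible reason the logical form fails is that Spark's SUM expects numeric input rather than booleans; treat that as an assumption.

```r
library(dplyr)

# Made-up example data: Value1 has missing values in both groups
df <- data.frame(ID = c("a", "a", "a", "b", "b", "b"),
                 Value1 = c(1, NA, 2, 3, NA, NA))

# Count non-missing values by summing 0/1 numbers instead of logicals
counts <- df %>%
  group_by(ID) %>%
  summarise(Value1_count = sum(ifelse(is.na(Value1), 0, 1)))

counts
```

Locally this gives 2 for group a and 1 for group b, the same as sum(!is.na(Value1)).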