How to find average of array of arrays in PySpark

I have a PySpark dataframe in which one of the columns (say B) is an array of arrays. Below is the PySpark dataframe:

+---+-----------------------------+---+
|A  |B                            |C  |
+---+-----------------------------+---+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |
|a  |[[5.0], [20.0, 80.0]]        |d  |
|a  |[[5.0], [25.0, 75.0]]        |e  |
|b  |[[5.0], [25.0, 75.0]]        |f  |
|b  |[[5.0], [12.0, 88.0]]        |g  |
+---+-----------------------------+---+

I want to find the number of elements and the average of all the elements (as separate columns) for each row.

Below is the expected output:

+---+-----------------------------+---+---+------+
|A  |B                            |C  |Num|   Avg|
+---+-----------------------------+---+---+------+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |4  | 23.75|
|a  |[[5.0], [20.0, 80.0]]        |d  |3  | 35.00|
|a  |[[5.0], [25.0, 75.0]]        |e  |3  | 35.00|
|b  |[[5.0], [25.0, 75.0]]        |f  |3  | 35.00|
|b  |[[5.0], [12.0, 88.0]]        |g  |3  | 35.00|
+---+-----------------------------+---+---+------+

What is an efficient way to find the average of all the elements in an array of arrays (per row) in PySpark?

Currently, I am using a udf to do these operations. Below is my current code:

from pyspark.sql import functions as F
import pyspark.sql.types as T
from pyspark.sql import *
from pyspark.sql.types import DecimalType
from pyspark.sql.functions import udf
import numpy as np

# UDF to find the number of elements
def len_array_of_arrays(anomaly_in_issue_group_col):
    return sum([len(array_element) for array_element in anomaly_in_issue_group_col])

udf_len_array_of_arrays = F.udf(len_array_of_arrays, T.IntegerType())

# UDF to find the average of all elements
def avg_array_of_arrays(anomaly_in_issue_group_col):
    return np.mean([element for array_element in anomaly_in_issue_group_col for element in array_element])

udf_avg_array_of_arrays = F.udf(avg_array_of_arrays, T.DecimalType())

df.withColumn("Num", udf_len_array_of_arrays(F.col("B"))).withColumn(
    "Avg", udf_avg_array_of_arrays(F.col("B"))
).show(20, False)

The udf that finds the number of elements in each row works. However, the udf that finds the average raises the following error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-176-3253feca2963> in <module>()
      1 #df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).show(20, False)
----> 2 df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).withColumn("Avg" ,  udf_avg_array_of_arrays(F.col("B")) ).show(20, False)

/usr/lib/spark/python/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    378             print(self._jdf.showString(n, 20, vertical))
    379         else:
--> 380             print(self._jdf.showString(n, int(truncate), vertical))
    381 
    382     def __repr__(self):

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/usr/lib/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

For Spark 2.4+, use flatten + aggregate:

from pyspark.sql.functions import expr

df.withColumn("Avg", expr("""
    aggregate(
        flatten(B)
      , (double(0) as total, int(0) as cnt)
      , (x,y) -> (x.total+y, x.cnt+1)
      , z -> round(z.total/z.cnt,2)
    ) 
 """)).show()
+-----------------------------+---+-----+
|B                            |C  |Avg  |
+-----------------------------+---+-----+
|[[5.0], [25.0, 25.0], [40.0]]|c  |23.75|
|[[5.0], [25.0, 80.0]]        |d  |36.67|
|[[5.0], [25.0, 75.0]]        |e  |35.0 |
+-----------------------------+---+-----+
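
The element count requested in the question can be obtained in the same builtin-only style, for example with size(flatten(B)) (a sketch reusing the expression above):

from pyspark.sql.functions import expr

df.withColumn("Num", expr("size(flatten(B))")) \
  .withColumn("Avg", expr("""
      aggregate(
          flatten(B)
        , (double(0) as total, int(0) as cnt)
        , (x,y) -> (x.total+y, x.cnt+1)
        , z -> round(z.total/z.cnt,2)
      )
  """)) \
  .show(truncate=False)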

Starting with Spark 1.4:

explode() the column that holds the arrays, once for each level of nesting. Use monotonically_increasing_id() to create an extra grouping key so that duplicate rows are not merged:

from pyspark.sql.functions import explode, sum, lit, avg, monotonically_increasing_id

df = spark.createDataFrame(
    (("a", [[1], [2, 3], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),  # DUPE!
     ("b", [[2, 3], [4]], "foo")),
    schema=("category", "arrays", "foo"))

df2 = (df.withColumn("id", monotonically_increasing_id())
       .withColumn("subarray", explode("arrays"))
       .withColumn("subarray", explode("subarray"))  # unnest another level
       .groupBy("category", "arrays", "foo", "id")
       .agg(sum(lit(1)).alias("number_of_elements"),
            avg("subarray").alias("avg")).drop("id"))
df2.show()
# +--------+------------------+---+------------------+----+  
# |category|            arrays|foo|number_of_elements| avg|
# +--------+------------------+---+------------------+----+
# |       a|[[5], [6, 0], [4]]|foo|                 4|3.75|
# |       b|     [[2, 3], [4]]|foo|                 3| 3.0|
# |       a|[[5], [6, 0], [4]]|foo|                 4|3.75|
# |       a|[[1], [2, 3], [4]]|foo|                 4| 2.5|
# +--------+------------------+---+------------------+----+
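
Applied to the question's original dataframe (columns A, B, C), the same pattern would look like this (a sketch; it uses count on the exploded column instead of sum(lit(1)), which yields the same element count here):

from pyspark.sql.functions import avg, count, explode, monotonically_increasing_id

result = (df.withColumn("id", monotonically_increasing_id())
            .withColumn("element", explode("B"))
            .withColumn("element", explode("element"))  # unnest the second level
            .groupBy("A", "B", "C", "id")
            .agg(count("element").alias("Num"),
                 avg("element").alias("Avg"))
            .drop("id"))
result.show(truncate=False)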

Spark 2.4 introduced 24 functions for working with complex types, along with higher-order functions (functions that take functions as arguments, like Python 3's functools.reduce). They do away with the boilerplate you see above. If you are using Spark 2.4+, see the flatten + aggregate approach shown above.
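
For completeness, the same idea expressed with the Python DataFrame API rather than a SQL expression (a sketch; it assumes PySpark 3.1+, where pyspark.sql.functions.aggregate is available, and the question's dataframe with column B):

from pyspark.sql import functions as F

flat = F.flatten("B")  # collapse the array of arrays into a single array

df.withColumn("Num", F.size(flat)) \
  .withColumn(
      "Avg",
      F.round(
          F.aggregate(flat, F.lit(0.0), lambda acc, x: acc + x) / F.size(flat),
          2,
      ),
  ) \
  .show(truncate=False)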