
Sum of array elements depending on a value condition in PySpark

我有一个 pyspark 数据框:

id   |   column
1    |  [0.2, 2, 3, 4, 3, 0.5]
2    |  [7, 0.3, 0.3, 8, 2,]

我想新建 3 列,分别对数组中小于 2、大于 2 和等于 2 的元素求和:


id   |   column               |  column<2 |  column>2   | column=2 
1    |  [0.2, 2, 3, 4, 3, 0.5]|  [0.7]    |  [12]       |  null
2    |  [7, 0.3, 0.3, 8, 2,]  | [0.6]     |  [15]       |  [2]

你能帮帮我吗? 谢谢


import pyspark.sql.functions as F

# Using the RDD API: filter each row's `column` list once per condition and sum
# the matching elements. `DataFrame.map` does not exist in Spark 2+, so the
# mapping has to go through `df.rdd`.
from pyspark.sql.types import StructType, StructField, DoubleType

BUCKETS = ('column<2', 'column>2', 'column=2')


def _rounded_sum(values):
    """Return the sum of *values* rounded to 2 places, or None if it is empty."""
    # None (rather than 0.0) gives the null the question expects when no
    # element matches a condition.
    return round(float(sum(values)), 2) if values else None


def _with_bucket_sums(row):
    """Append the <2, >2 and ==2 sums of ``row.column`` to the row's values."""
    # A null array or null elements would make the comparisons raise.
    items = [i for i in (row.column or []) if i is not None]
    return tuple(row) + (
        _rounded_sum([i for i in items if i < 2]),
        _rounded_sum([i for i in items if i > 2]),
        _rounded_sum([i for i in items if i == 2]),
    )


# The original columns are carried through the map, so no join is needed.
# Joining two independently built DataFrames on monotonically_increasing_id()
# is unsafe: the generated ids depend on partitioning and need not line up.
out_schema = StructType(
    df.schema.fields + [StructField(name, DoubleType()) for name in BUCKETS]
)
df = df.rdd.map(_with_bucket_sums).toDF(out_schema)

df.show()

| id|              column|col<2| col>2|col=2|
|  0|[0.2, 2.0, 3.0, 4...|[0.7]|[10.0]|[2.0]|
|  1|[7.0, 0.3, 0.3, 8...|[0.6]|[15.0]|[2.0]|

对于 Spark 2.4+,可以像这样组合使用 `aggregate` 和 `filter` 高阶函数(注意:没有匹配元素时,结果是 0.0 而不是 null):

df.withColumn("column<2", expr("aggregate(filter(column, x -> x < 2), 0D, (x, acc) -> acc + x)")) \
  .withColumn("column>2", expr("aggregate(filter(column, x -> x > 2), 0D, (x, acc) -> acc + x)")) \
  .withColumn("column=2", expr("aggregate(filter(column, x -> x == 2), 0D, (x, acc) -> acc + x)")) \


|id |column                        |column<2|column>2|column=2|
|1  |[0.2, 2.0, 3.0, 4.0, 3.0, 0.5]|0.7     |10.0    |2.0     |
|2  |[7.0, 0.3, 0.3, 8.0, 2.0]     |0.6     |15.0    |2.0     |

对于 Spark 2.4+,也可以只调用一次 `aggregate` 函数完成全部计算:

from pyspark.sql.functions import expr

# I changed the 2nd array item of id=1 from 2.0 to 2.1, so id=1 has no element
# equal to 2.0 (its `column=2` result is therefore null).
df = spark.createDataFrame(
    [(1, [0.2, 2.1, 3., 4., 3., 0.5]), (2, [7., 0.3, 0.3, 8., 2.])],
    ['id', 'column'],
)

df.withColumn('data', expr("""

      aggregate(
        /* ArrayType argument */
        column,
        /* zero: three typed nulls, one slot per condition (<2, >2, =2) */
        array(CAST(NULL AS double), CAST(NULL AS double), CAST(NULL AS double)),
        /* merge: add y to the slot of acc selected by comparing y with 2 */
        (acc, y) ->
          CASE
            WHEN y < 2.0 THEN array(IFNULL(acc[0], 0) + y, acc[1], acc[2])
            WHEN y > 2.0 THEN array(acc[0], IFNULL(acc[1], 0) + y, acc[2])
                         ELSE array(acc[0], acc[1], IFNULL(acc[2], 0) + y)
          END,
        /* finish: convert the array into a named_struct */
        acc -> named_struct('column<2', acc[0], 'column>2', acc[1], 'column=2', acc[2])
      )

""")).selectExpr('id', 'data.*').show()
#| id|column<2|column>2|column=2|
#|  1|     0.7|    12.1|    null|
#|  2|     0.6|    15.0|     2.0|

在 Spark 2.4 之前,针对 ArrayType 的内置函数支持有限,可以先用 `explode` 展开数组,再用 `groupby` + `pivot` 汇总:

from pyspark.sql.functions import sum as fsum, expr

# One row per array element: tag each element with its bucket name, then
# pivot the buckets into columns and sum per id. With explode_outer, an id
# whose array is null/empty still yields one row with a null item, so it shows
# up with all-null sums instead of disappearing.
df.selectExpr('id', 'explode_outer(column) as item') \
  .withColumn('g', expr('if(item < 2, "column<2", if(item > 2, "column>2", "column=2"))')) \
  .groupby('id') \
  .pivot('g', ["column<2", "column>2", "column=2"]) \
  .agg(fsum('item')) \
  .orderBy('id') \
  .show()
#| id|column<2|column>2|column=2|
#|  1|     0.7|    12.1|    null|
#|  2|     0.6|    15.0|     2.0|

如果 explode 很慢(例如 Spark 2.3 之前版本中存在的 SPARK-21657 问题),请使用 UDF:

from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, DoubleType

# Return type of the UDF. The field order fixes the column order produced by
# `data.*`; it matches the other answers (<2, >2, =2).
schema = StructType([
    StructField("column<2", DoubleType()),
    StructField("column>2", DoubleType()),
    StructField("column=2", DoubleType()),
])

def split_data(arr):
    """Sum the elements of *arr* into three buckets relative to 2.

    Returns a dict with up to three keys -- 'column>2', 'column<2' and
    'column=2' -- each holding the sum of the elements in that bucket. A key
    is absent when no element falls in its bucket; the struct UDF should then
    report that field as null. A ``None`` array and ``None`` elements are
    ignored.
    """
    d = {}
    for y in arr or []:
        if y is None:
            # Comparing None with a number raises TypeError in Python 3.
            continue
        if y > 2:
            key = 'column>2'
        elif y < 2:
            key = 'column<2'
        else:
            key = 'column=2'
        d[key] = d.get(key, 0) + y
    return d

# Register the plain-Python splitter as a UDF whose result is the struct
# declared above, then expand that struct's fields into top-level columns.
udf_split_data = udf(split_data, returnType=schema)

(df
 .withColumn('data', udf_split_data('column'))
 .selectExpr('id', 'data.*')
 .show())