Find average of max two next values of column in a Spark Dataframe

Below is a sample dataframe:

id  nd  time    value   
3   n1  7       50  
10  n1  3       40  
11  n1  5       30  
1   n1  2       20  
2   n1  6       20  
9   n1  4       10  
4   n1  1       10

Here the maximum time is 7, so I have to find the two largest values among the rows whose time is less than 7: those are 40 and 30, and then compute newValue = value - avg(40, 30) = 50 - (40 + 30)/2 = 15.

The next largest time is 6, so I have to find the two largest values among the rows whose time is less than 6 (again 40 and 30), giving newValue = 20 - avg(40, 30) = -15.

Similarly, I have to compute this for all the remaining rows, except the last two, which should get null.

id  nd  time    value    NewVal
3   n1  7       50       15
10  n1  3       40       25
11  n1  5       30       0    ((40+20)/2 = 30; 30 - 30 = 0)
1   n1  2       20       Null
2   n1  6       20      -15
9   n1  4       10       20
4   n1  1       10       Null
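
For reference, here is a minimal plain-Python sketch of the intended computation (not Spark code; the tuples are just the sample rows re-typed by hand): for each row, take all values with a strictly smaller time, average the two largest of them, and subtract that average from the row's value.

rows = [(3, 7, 50), (10, 3, 40), (11, 5, 30), (1, 2, 20),
        (2, 6, 20), (9, 4, 10), (4, 1, 10)]            # (id, time, value)

for rid, t, v in rows:
    # all values from rows with a strictly smaller time, sorted descending
    earlier = sorted((val for _, tt, val in rows if tt < t), reverse=True)
    # subtract the average of the two largest; None when fewer than two exist
    new_val = v - sum(earlier[:2]) / 2 if len(earlier) >= 2 else None
    print(rid, t, v, new_val)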

I have written a UDF to solve your problem. Note that, by your own logic, the NewValue for time 4 should be -20 rather than 20, which is what my code returns. Please confirm.

>>> from pyspark.sql.types import StringType
>>> from pyspark.sql.functions import udf,col,concat_ws,collect_list
>>> from pyspark.sql.window import Window
>>> df.show()
+---+---+----+-----+
| id| nd|time|value|
+---+---+----+-----+
|  3| n1|   7|   50|
| 10| n1|   3|   40|
| 11| n1|   5|   30|
|  1| n1|   2|   20|
|  2| n1|   6|   20|
|  9| n1|   4|   10|
|  4| n1|   1|   10|
+---+---+----+-----+

>>> df.cache()
>>> cnt = df.count()
>>> def sampleFun(allvalue, value):
...     # allvalue is the comma-joined window list (current row first, then
...     # earlier rows in descending time order); drop the current row's value
...     firstValue = allvalue.replace(str(value) + ',', '', 1)
...     firstList = [int(i) for i in firstValue.split(',')]
...     if len(firstList) > 1:
...             # largest and second largest of the earlier values
...             max_1 = max(firstList)
...             secondValue = firstValue.replace(str(max_1) + ',', '', 1)
...             secondList = [int(i) for i in secondValue.split(',')]
...             max_2 = max(secondList)
...             avgValue = (max_1 + max_2) / 2
...             output = int(value) - avgValue
...             return str(output)
...     else:
...             # fewer than two earlier values: leave NewValue empty
...             return ''

>>> sampleUDF = udf(sampleFun, StringType())
>>> W = Window.rowsBetween(0,cnt).orderBy(col("time").desc())
>>> df1 = df.withColumn("ListValue", concat_ws(",",collect_list(col("value")).over(W)))
>>> df2 = df1.withColumn("NewValue", sampleUDF(col("ListValue"), col("value"))).drop("ListValue")
>>> df2.show()
+---+---+----+-----+--------+                                                   
| id| nd|time|value|NewValue|
+---+---+----+-----+--------+
|  3| n1|   7|   50|    15.0|
|  2| n1|   6|   20|   -15.0|
| 11| n1|   5|   30|     0.0|
|  9| n1|   4|   10|   -20.0|
| 10| n1|   3|   40|    25.0|
|  1| n1|   2|   20|        |
|  4| n1|   1|   10|        |
+---+---+----+-----+--------+
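
As a quick sanity check of the plain-Python helper outside Spark, the first argument below mimics the comma-joined ListValue that collect_list produces for the time = 7 row (current row first, then earlier rows in descending time order):

>>> sampleFun("50,20,30,10,40,20,10", "50")   # 50 - (40 + 30)/2
'15.0'
>>> sampleFun("20,10", "20")                  # only one earlier value left
''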

I would use a Window function if the data can be partitioned properly, for example by the nd column in your sample (or, if your data can be loaded into a single partition, remove partitionBy('nd') from the WindowSpec w1 below):

from pyspark.sql.functions import sort_array, collect_list, expr
from pyspark.sql import Window 

w1 = Window.partitionBy('nd').orderBy('time').rowsBetween(Window.unboundedPreceding, -1) 

df.withColumn('v1', sort_array(collect_list('value').over(w1),False)) \
    .withColumn('NewVal', expr('value - (v1[0] + v1[1])*0.5')) \
    .show(10, False)                                                                        
+---+---+----+-----+------------------------+------+                            
|id |nd |time|value|v1                      |NewVal|
+---+---+----+-----+------------------------+------+
|4  |n1 |1   |10   |[]                      |null  |
|1  |n1 |2   |20   |[10]                    |null  |
|10 |n1 |3   |40   |[20, 10]                |25.0  |
|9  |n1 |4   |10   |[40, 20, 10]            |-20.0 |
|11 |n1 |5   |30   |[40, 20, 10, 10]        |0.0   |
|2  |n1 |6   |20   |[40, 30, 20, 10, 10]    |-15.0 |
|3  |n1 |7   |50   |[40, 30, 20, 20, 10, 10]|15.0  |
+---+---+----+-----+------------------------+------+
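
The nulls in the first two rows come from v1[1] being null when fewer than two earlier values exist. If you prefer to make that condition explicit, a small variant of the same expression (a sketch, equivalent on this sample data) is:

from pyspark.sql.functions import sort_array, collect_list, expr

# Same WindowSpec w1 as above; the size() check spells out when NewVal is
# NULL instead of relying on v1[1] being NULL for the first two rows.
df.withColumn('v1', sort_array(collect_list('value').over(w1), False)) \
    .withColumn('NewVal', expr('CASE WHEN size(v1) >= 2 THEN value - (v1[0] + v1[1])/2 END')) \
    .show(10, False)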

Update: to calculate the mean of an arbitrary number N of top values:

from pyspark.sql.functions import sort_array, collect_list, col, round                                      

N = 3

df.withColumn('v1', sort_array(collect_list('value').over(w1),False)) \
    .withColumn('NewVal', round(col('value') - sum(col('v1')[i] for i in range(N))/N,2)) \
    .show(10, False)
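
Alternatively, assuming Spark 2.4+ is available, the same top-N average can be sketched with the SQL higher-order functions slice and aggregate; the size() guard keeps the result null for rows that have fewer than N earlier values, matching the behaviour above:

from pyspark.sql.functions import sort_array, collect_list, expr

N = 3

# v1 is sorted descending, so slice() keeps its N largest values; aggregate()
# sums them, and size() guards rows that have fewer than N earlier values.
sql = ("CASE WHEN size(v1) >= {n} THEN "
       "round(value - aggregate(slice(v1, 1, {n}), 0D, (acc, x) -> acc + x) / {n}, 2) "
       "END").format(n=N)

df.withColumn('v1', sort_array(collect_list('value').over(w1), False)) \
    .withColumn('NewVal', expr(sql)) \
    .show(10, False)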