Find average of max two next values of column in a Spark Dataframe
Here is a sample dataframe:
id nd time value
3 n1 7 50
10 n1 3 40
11 n1 5 30
1 n1 2 20
2 n1 6 20
9 n1 4 10
4 n1 1 10
Here the maximum time is 7, so I have to find the two largest values among the rows whose time is less than 7: those are 40 and 30, and then newValue = value - avg(40, 30) = 50 - (40 + 30)/2 = 15.
The next largest time is 6, so I have to find the two largest values among the rows whose time is less than 6 (again 40 and 30, so newValue = 20 - avg(40, 30) = -15).
Likewise I have to compute this for every row, except the last two rows, which should get null. Expected output:
id nd time value NewVal
3 n1 7 50 15
10 n1 3 40 25
11 n1 5 30 0 ((40+20)/2 = 30, 30-30 = 0)
1 n1 2 20 Null
2 n1 6 20 -15
9 n1 4 10 20
4 n1 1 10 Null
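For reference, the sample dataframe above could be built like this (a minimal sketch; it assumes an existing SparkSession called spark and lets Spark infer the column types):
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# sample data from the table above: (id, nd, time, value)
df = spark.createDataFrame(
    [(3, "n1", 7, 50), (10, "n1", 3, 40), (11, "n1", 5, 30),
     (1, "n1", 2, 20), (2, "n1", 6, 20), (9, "n1", 4, 10), (4, "n1", 1, 10)],
    ["id", "nd", "time", "value"])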
I have written a UDF to solve your problem. Per your own logic, the NewValue for time 4 should be -20 rather than 20; that is what my code produces. Please confirm.
>>> from pyspark.sql.types import StringType
>>> from pyspark.sql.functions import udf,col,concat_ws,collect_list
>>> from pyspark.sql.window import Window
>>> df.show()
+---+---+----+-----+
| id| nd|time|value|
+---+---+----+-----+
| 3| n1| 7| 50|
| 10| n1| 3| 40|
| 11| n1| 5| 30|
| 1| n1| 2| 20|
| 2| n1| 6| 20|
| 9| n1| 4| 10|
| 4| n1| 1| 10|
+---+---+----+-----+
>>> df.cache()
>>> cnt = df.count()
>>> def sampleFun(allvalue, value):
...     # allvalue: this row's value plus the values of all rows with a smaller time, comma-joined
...     output = ''
...     firstValue = allvalue.replace(str(value) + ',', '', 1)   # drop the current row's own value
...     firstList = [int(i) for i in firstValue.split(',')]
...     if len(firstList) > 1:
...         max_1 = max(firstList)
...         secondValue = firstValue.replace(str(max_1) + ',', '', 1)
...         secondList = [int(i) for i in secondValue.split(',')]
...         max_2 = max(secondList)
...         avgValue = (max_1 + max_2) / 2
...         output = int(value) - avgValue
...         return str(output)
...     else:
...         return ''
>>> sampleUDF = udf(sampleFun, StringType())
>>> W = Window.rowsBetween(0, cnt).orderBy(col("time").desc())   # current row plus all rows with a smaller time
>>> df1 = df.withColumn("ListValue", concat_ws(",",collect_list(col("value")).over(W)))
>>> df2 = df1.withColumn("NewValue", sampleUDF(col("ListValue"), col("value"))).drop("ListValue")
>>> df2.show()
+---+---+----+-----+--------+
| id| nd|time|value|NewValue|
+---+---+----+-----+--------+
| 3| n1| 7| 50| 15.0|
| 2| n1| 6| 20| -15.0|
| 11| n1| 5| 30| 0.0|
| 9| n1| 4| 10| -20.0|
| 10| n1| 3| 40| 25.0|
| 1| n1| 2| 20| |
| 4| n1| 1| 10| |
+---+---+----+-----+--------+
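As a side note, the same UDF idea can be sketched without the string round-trip: collect the smaller-time values into an array column and have the UDF return a real double, so the last two rows get a proper null instead of an empty string. This is only an alternative sketch under those assumptions, not the code shown above:
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import udf, col, collect_list
from pyspark.sql.window import Window

# frame: all rows with a smaller time than the current row (starts one row after the current one)
w = Window.orderBy(col("time").desc()).rowsBetween(1, Window.unboundedFollowing)

@udf(DoubleType())
def avg_diff(value, later_values):
    if later_values is None or len(later_values) < 2:
        return None  # fewer than two smaller-time values -> null
    top_two = sorted(later_values, reverse=True)[:2]
    return float(value) - sum(top_two) / 2.0

df.withColumn("NewValue", avg_diff(col("value"), collect_list("value").over(w))).show()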
I would use Window functions if the data can be partitioned properly, for example by the nd column in your case (or, if your data can be loaded into a single partition, drop partitionBy('nd') from the WindowSpec w1 below):
from pyspark.sql.functions import sort_array, collect_list, expr
from pyspark.sql import Window

# frame: all rows of the same nd with a strictly smaller time (ends one row before the current row)
w1 = Window.partitionBy('nd').orderBy('time').rowsBetween(Window.unboundedPreceding, -1)

# v1: the earlier values sorted descending; NewVal: value minus the average of the top two
df.withColumn('v1', sort_array(collect_list('value').over(w1), False)) \
  .withColumn('NewVal', expr('value - (v1[0] + v1[1])*0.5')) \
  .show(10, False)
+---+---+----+-----+------------------------+------+
|id |nd |time|value|v1 |NewVal|
+---+---+----+-----+------------------------+------+
|4 |n1 |1 |10 |[] |null |
|1 |n1 |2 |20 |[10] |null |
|10 |n1 |3 |40 |[20, 10] |25.0 |
|9 |n1 |4 |10 |[40, 20, 10] |-20.0 |
|11 |n1 |5 |30 |[40, 20, 10, 10] |0.0 |
|2 |n1 |6 |20 |[40, 30, 20, 10, 10] |-15.0 |
|3 |n1 |7 |50 |[40, 30, 20, 20, 10, 10]|15.0 |
+---+---+----+-----+------------------------+------+
Update: to calculate the average of any N max values:
from pyspark.sql.functions import sort_array, collect_list, col, round

N = 3
# Python's built-in sum adds the Column expressions v1[0] .. v1[N-1];
# rows with fewer than N earlier values yield null because some v1[i] is null.
df.withColumn('v1', sort_array(collect_list('value').over(w1), False)) \
  .withColumn('NewVal', round(col('value') - sum(col('v1')[i] for i in range(N))/N, 2)) \
  .show(10, False)
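On Spark 2.4+, the same top-N average could also be expressed with the built-in higher-order functions slice and aggregate instead of a Python-side loop. A sketch under that assumption, reusing N, w1 and df from above:
from pyspark.sql.functions import sort_array, collect_list, expr

N = 3
df.withColumn('v1', sort_array(collect_list('value').over(w1), False)) \
  .withColumn('NewVal', expr(f"""
      CASE WHEN size(v1) >= {N}
           THEN round(value - aggregate(slice(v1, 1, {N}), 0D, (acc, x) -> acc + x) / {N}, 2)
      END""")) \
  .show(10, False)
The CASE WHEN makes the "fewer than N earlier values" rows explicitly null, which matches the behavior of the version above (where a null v1[i] propagates).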