PySpark:将不同的 window 大小应用于 pyspark 中的数据框
PySpark: applying varying window sizes to a dataframe in pyspark
我有一个如下所示的 spark 数据框。
date
ID
window_size
qty
01/01/2020
1
2
1
02/01/2020
1
2
2
03/01/2020
1
2
3
04/01/2020
1
2
4
01/01/2020
2
3
1
02/01/2020
2
3
2
03/01/2020
2
3
3
04/01/2020
2
3
4
我正在尝试对数据框中的每个 ID 应用大小为 window_size 的滚动 window 并获得滚动总和。基本上我正在计算滚动总和(pandas 中的 pd.groupby.rolling(window=n).sum()
),其中 window 大小 (n) 可以按组更改。
预期输出
date
ID
window_size
qty
rolling_sum
01/01/2020
1
2
1
null
02/01/2020
1
2
2
3
03/01/2020
1
2
3
5
04/01/2020
1
2
4
7
01/01/2020
2
3
1
null
02/01/2020
2
3
2
null
03/01/2020
2
3
3
6
04/01/2020
2
3
4
9
我正在努力寻找一种在大型数据帧(+- 350M 行)上运行且速度足够快的解决方案。
我试过的
我尝试了下面的解决方案 :
想法是先使用 sf.collect_list
,然后正确地分割 ArrayType
列。
import pyspark.sql.types as st
import pyspark.sql.function as sf
window = Window.partitionBy('id').orderBy(params['date'])
output = (
sdf
.withColumn("qty_list", sf.collect_list('qty').over(window))
.withColumn("count", sf.count('qty').over(window))
.withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
.otherwise(sf.slice('qty_list', sf.col('count'), sf.col('window_size'))))
).show()
但是这会产生以下错误:
TypeError: Column is not iterable
我也试过使用 sf.expr
如下
window = Window.partitionBy('id').orderBy(params['date'])
output = (
sdf
.withColumn("qty_list", sf.collect_list('qty').over(window))
.withColumn("count", sf.count('qty').over(window))
.withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
.otherwise(sf.expr("slice('window_size', 'count', 'window_size')")))
).show()
产生:
data type mismatch: argument 1 requires array type, however,
''qty_list'' is of string type.; line 1 pos 0;
我尝试将 qty_list
列手动转换为 ArrayType(IntegerType())
,结果相同。
我尝试使用 UDF,但在 1.5 小时左右后失败并出现多个内存不足错误。
问题
阅读火花 documentation 向我建议我应该能够将列传递给 sf.slice()
,我做错了什么吗? TypeError
来自哪里?
有没有更好的方法不用sf.collect_list()
and/orsf.slice()
就可以达到我想要的效果?
如果所有其他方法都失败了,使用 udf 执行此操作的最佳方法是什么?我尝试了同一个 udf 的不同版本,并试图确保 udf 是 spark 必须执行的最后一个操作,但都失败了。
关于您遇到的错误:
- 第一个表示您不能使用 DataFrame API 函数将列传递给
slice
(除非您有 Spark 3.1+)。但是当您尝试在 SQL 表达式中使用它时,您已经掌握了它。
- 发生第二个错误是因为您传递了
expr
中引用的列名。它应该是 slice(qty_list, count, window_size)
,否则 Spark 会将它们视为字符串,因此会出现错误消息。
也就是说,您差不多明白了,您需要更改切片表达式以获得正确的数组大小,然后使用 aggregate
函数对结果数组的值求和。试试这个:
from pyspark.sql import Window
import pyspark.sql.functions as F
w = Window.partitionBy('id').orderBy('date')
output = df.withColumn("qty_list", F.collect_list('qty').over(w)) \
.withColumn("rn", F.row_number().over(w)) \
.withColumn(
"qty_list",
F.when(
F.col('rn') < F.col('window_size'),
None
).otherwise(F.expr("slice(qty_list, rn-window_size+1, window_size)"))
).withColumn(
"rolling_sum",
F.expr("aggregate(qty_list, 0D, (acc, x) -> acc + x)").cast("int")
).drop("qty_list", "rn")
output.show()
#+----------+---+-----------+---+-----------+
#| date| ID|window_size|qty|rolling_sum|
#+----------+---+-----------+---+-----------+
#|01/01/2020| 1| 2| 1| null|
#|02/01/2020| 1| 2| 2| 3|
#|03/01/2020| 1| 2| 3| 5|
#|04/01/2020| 1| 2| 4| 7|
#|01/01/2020| 2| 3| 1| null|
#|02/01/2020| 2| 3| 2| null|
#|03/01/2020| 2| 3| 3| 6|
#|04/01/2020| 2| 3| 4| 9|
#+----------+---+-----------+---+-----------+
我有一个如下所示的 spark 数据框。
date | ID | window_size | qty |
---|---|---|---|
01/01/2020 | 1 | 2 | 1 |
02/01/2020 | 1 | 2 | 2 |
03/01/2020 | 1 | 2 | 3 |
04/01/2020 | 1 | 2 | 4 |
01/01/2020 | 2 | 3 | 1 |
02/01/2020 | 2 | 3 | 2 |
03/01/2020 | 2 | 3 | 3 |
04/01/2020 | 2 | 3 | 4 |
我正在尝试对数据框中的每个 ID 应用大小为 window_size 的滚动 window 并获得滚动总和。基本上我正在计算滚动总和(pandas 中的 pd.groupby.rolling(window=n).sum()
),其中 window 大小 (n) 可以按组更改。
预期输出
date | ID | window_size | qty | rolling_sum |
---|---|---|---|---|
01/01/2020 | 1 | 2 | 1 | null |
02/01/2020 | 1 | 2 | 2 | 3 |
03/01/2020 | 1 | 2 | 3 | 5 |
04/01/2020 | 1 | 2 | 4 | 7 |
01/01/2020 | 2 | 3 | 1 | null |
02/01/2020 | 2 | 3 | 2 | null |
03/01/2020 | 2 | 3 | 3 | 6 |
04/01/2020 | 2 | 3 | 4 | 9 |
我正在努力寻找一种在大型数据帧(+- 350M 行)上运行且速度足够快的解决方案。
我试过的
我尝试了下面的解决方案
想法是先使用 sf.collect_list
,然后正确地分割 ArrayType
列。
import pyspark.sql.types as st
import pyspark.sql.function as sf
window = Window.partitionBy('id').orderBy(params['date'])
output = (
sdf
.withColumn("qty_list", sf.collect_list('qty').over(window))
.withColumn("count", sf.count('qty').over(window))
.withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
.otherwise(sf.slice('qty_list', sf.col('count'), sf.col('window_size'))))
).show()
但是这会产生以下错误:
TypeError: Column is not iterable
我也试过使用 sf.expr
如下
window = Window.partitionBy('id').orderBy(params['date'])
output = (
sdf
.withColumn("qty_list", sf.collect_list('qty').over(window))
.withColumn("count", sf.count('qty').over(window))
.withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
.otherwise(sf.expr("slice('window_size', 'count', 'window_size')")))
).show()
产生:
data type mismatch: argument 1 requires array type, however, ''qty_list'' is of string type.; line 1 pos 0;
我尝试将 qty_list
列手动转换为 ArrayType(IntegerType())
,结果相同。
我尝试使用 UDF,但在 1.5 小时左右后失败并出现多个内存不足错误。
问题
阅读火花 documentation 向我建议我应该能够将列传递给
sf.slice()
,我做错了什么吗?TypeError
来自哪里?有没有更好的方法不用
sf.collect_list()
and/orsf.slice()
就可以达到我想要的效果?如果所有其他方法都失败了,使用 udf 执行此操作的最佳方法是什么?我尝试了同一个 udf 的不同版本,并试图确保 udf 是 spark 必须执行的最后一个操作,但都失败了。
关于您遇到的错误:
- 第一个表示您不能使用 DataFrame API 函数将列传递给
slice
(除非您有 Spark 3.1+)。但是当您尝试在 SQL 表达式中使用它时,您已经掌握了它。 - 发生第二个错误是因为您传递了
expr
中引用的列名。它应该是slice(qty_list, count, window_size)
,否则 Spark 会将它们视为字符串,因此会出现错误消息。
也就是说,您差不多明白了,您需要更改切片表达式以获得正确的数组大小,然后使用 aggregate
函数对结果数组的值求和。试试这个:
from pyspark.sql import Window
import pyspark.sql.functions as F
w = Window.partitionBy('id').orderBy('date')
output = df.withColumn("qty_list", F.collect_list('qty').over(w)) \
.withColumn("rn", F.row_number().over(w)) \
.withColumn(
"qty_list",
F.when(
F.col('rn') < F.col('window_size'),
None
).otherwise(F.expr("slice(qty_list, rn-window_size+1, window_size)"))
).withColumn(
"rolling_sum",
F.expr("aggregate(qty_list, 0D, (acc, x) -> acc + x)").cast("int")
).drop("qty_list", "rn")
output.show()
#+----------+---+-----------+---+-----------+
#| date| ID|window_size|qty|rolling_sum|
#+----------+---+-----------+---+-----------+
#|01/01/2020| 1| 2| 1| null|
#|02/01/2020| 1| 2| 2| 3|
#|03/01/2020| 1| 2| 3| 5|
#|04/01/2020| 1| 2| 4| 7|
#|01/01/2020| 2| 3| 1| null|
#|02/01/2020| 2| 3| 2| null|
#|03/01/2020| 2| 3| 3| 6|
#|04/01/2020| 2| 3| 4| 9|
#+----------+---+-----------+---+-----------+