使用 PySpark 对多列执行延迟

Perform Lag over multiple columns using PySpark

我是 PySpark 的新手,但我正在尝试在我的代码中使用最佳实践。我有一个 PySpark 数据框,我想滞后多列,用滞后值替换原始值。示例:

ID     date        value1     value2     value3
1      2021-12-23  1.1        4.0        2.2
2      2021-12-21  2.4        1.6        11.9
1      2021-12-24  5.4        3.2        7.8
2      2021-12-22  4.2        1.4        9.0
1      2021-12-26  2.3        5.2        7.6
.
.
.

我想根据 ID 获取所有值,按 date 对它们进行排序,然后将这些值滞后一些。我到目前为止的代码:

from pyspark.sql import functions as F, Window

window = Window.partitionBy(F.col("ID")).orderBy(F.col("date"))

valueColumns = ['value1', 'value2', 'value3']

df = F.lag(valueColumns, offset=shiftAmount).over(window)

我想要的输出是:

ID     date        value1     value2     value3
1      2021-12-23  Null       Null       Null
2      2021-12-21  Null       Null       Null
1      2021-12-24  1.1        4.0        2.2
2      2021-12-22  2.4        1.6        11.9
1      2021-12-26  5.4        3.2        7.86
.
.
.

我遇到的问题是,据我所知,F.lag 只接受一个列。我正在寻找有关如何最好地完成此任务的建议。我想我可以使用 for 循环来附加移位的列或其他东西,但这看起来很不雅观。谢谢!

对列名的简单列表理解应该可以完成这项工作:

df = df.select(
    "ID", "date",
    *[F.lag(c, offset=shiftAmount).over(window).alias(c) for c in valueColumns]
)