根据不同列中的多个值在 Spark df 中创建一个列

Create a column in Spark df based on several values in different column

我有一个大致如下所示的 spark df:

company ID  quarter     metric
  12       31-12-2019    54.3     
  12       30-09-2019    48.2
  12       30-06-2019    32.3
  12       30-03-2019    54.3
  23       31-12-2018    54.3
  23       30-09-2018    48.2
  23       30-06-2018    32.3
  23       30-03-2018    54.3
  45       31-12-2021    54.3
  45       30-09-2021    48.2
  45       30-06-2021    32.3
  45       30-03-2021    54.3
  45       31-12-2021    54.3
  45       30-09-2020    48.2
  45       30-06-2020    32.3
  45       30-03-2020    54.3
  ..           ..         ..

对于每个公司 ID 的每个季度行,我需要计算以下季度的年值,即对于公司 ID = 45 和季度 = 30-06-2020,年值将等于:

30-03-2021    54.3
31-12-2020    54.3
30-09-2020    48.2
30-06-2020    32.3
            --------
              189,1

结果:

   company ID  quarter     metric   annual
      12       31-12-2019    54.3     
      12       30-09-2019    48.2
      12       30-06-2019    32.3
      12       30-03-2019    54.3
      23       31-12-2018    54.3
      23       30-09-2018    48.2
      23       30-06-2018    32.3
      23       30-03-2018    54.3
      45       31-12-2021    54.3
      45       30-09-2021    48.2
      45       30-06-2021    32.3
      45       30-03-2021    54.3
      45       31-12-2021    54.3
      45       30-09-2020    48.2
      45       30-06-2020    32.3   **189,1**
      45       30-03-2020    54.3
      ..           ..         ..

在 pandas 中,我可能会按实体 ID 分组,然后尝试根据索引或类似的东西计算列。在 Spark/Python 中最有效的方法是什么?

日期可以转换为自 1970 年以来的天数,并且可以在 Scala 上使用 Window 范围为 366 天的带有“总和”的函数:

val df = Seq(
  (12, "31-12-2019", 54.3),
  (12, "30-09-2019", 48.2),
  (12, "30-06-2019", 32.3),
  (12, "30-03-2019", 54.3),
  (23, "31-12-2018", 54.3),
  (23, "30-09-2018", 48.2),
  (23, "30-06-2018", 32.3),
  (23, "30-03-2018", 54.3),
  (45, "31-12-2021", 54.3),
  (45, "30-09-2021", 48.2),
  (45, "30-06-2021", 32.3),
  (45, "30-03-2021", 54.3),
  (45, "31-12-2021", 54.3),
  (45, "30-09-2020", 48.2),
  (45, "30-06-2020", 32.3),
  (45, "30-03-2020", 54.3),
)
  .toDF("company ID", "quarter", "metric")

val companyIdWindow = Window.partitionBy("company ID").rangeBetween(-366, Window.currentRow).orderBy("days")

import java.util.concurrent.TimeUnit
val secondsInDay = TimeUnit.DAYS.toSeconds(1)
df
  .withColumn("days", unix_timestamp($"quarter", "dd-MM-yyyy") / secondsInDay)
  .withColumn("annual", sum("metric").over(companyIdWindow))
  .drop("days")

结果:

+----------+----------+------+------------------+
|company ID|quarter   |metric|annual            |
+----------+----------+------+------------------+
|23        |30-03-2018|54.3  |54.3              |
|23        |30-06-2018|32.3  |86.6              |
|23        |30-09-2018|48.2  |134.8             |
|23        |31-12-2018|54.3  |189.10000000000002|
|45        |30-03-2020|54.3  |54.3              |
|45        |30-06-2020|32.3  |86.6              |
|45        |30-09-2020|48.2  |134.8             |
|45        |30-03-2021|54.3  |189.10000000000002|
|45        |30-06-2021|32.3  |167.10000000000002|
|45        |30-09-2021|48.2  |183.0             |
|45        |31-12-2021|54.3  |243.40000000000003|
|45        |31-12-2021|54.3  |243.40000000000003|
|12        |30-03-2019|54.3  |54.3              |
|12        |30-06-2019|32.3  |86.6              |
|12        |30-09-2019|48.2  |134.8             |
|12        |31-12-2019|54.3  |189.10000000000002|
+----------+----------+------+------------------+