Sum value between overlapping interval slices per group

I have a PySpark DataFrame that looks like this:

import pandas as pd
from pyspark.sql import SparkSession

spark = (SparkSession.builder 
         .master("local") 
         .getOrCreate())
spark.conf.set("spark.sql.session.timeZone", "UTC")

INPUT = {
    "idx": [1, 1, 1, 1, 0],
    "consumption": [10.0, 20.0, 30.0, 40.0, 5.0],
    "valid_from": [
        pd.Timestamp("2019-01-01 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-02 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-03 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-06 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-01 00:00:00+00:00", tz="UTC"),
    ],
    "valid_to": [
        pd.Timestamp("2019-01-02 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-05 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-05 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-08 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-02 00:00:00+00:00", tz="UTC"),
    ],
}
df=pd.DataFrame.from_dict(INPUT)
spark.createDataFrame(df).show()

>>>
   +---+-----------+-------------------+-------------------+
   |idx|consumption|         valid_from|           valid_to|
   +---+-----------+-------------------+-------------------+
   |  1|       10.0|2019-01-01 00:00:00|2019-01-02 00:00:00|
   |  1|       20.0|2019-01-02 00:00:00|2019-01-05 00:00:00|
   |  1|       30.0|2019-01-03 00:00:00|2019-01-05 00:00:00|
   |  1|       40.0|2019-01-06 00:00:00|2019-01-08 00:00:00|
   |  0|        5.0|2019-01-01 00:00:00|2019-01-02 00:00:00|
   +---+-----------+-------------------+-------------------+

I just want to sum consumption over the overlapping interval slices for each idx, like this:

+---+-------------------+-----------+
|idx|          timestamp|consumption|
+---+-------------------+-----------+
|  1|2019-01-01 00:00:00|       10.0|
|  1|2019-01-02 00:00:00|       20.0|
|  1|2019-01-03 00:00:00|       50.0|
|  1|2019-01-04 00:00:00|       50.0|
|  1|2019-01-05 00:00:00|        0.0|
|  1|2019-01-06 00:00:00|       40.0|
|  1|2019-01-07 00:00:00|       40.0|
|  1|2019-01-08 00:00:00|        0.0|
|  0|2019-01-01 00:00:00|        5.0|
|  0|2019-01-02 00:00:00|        0.0|
+---+-------------------+-----------+

You can use sequence to expand the intervals into single days, explode the list of days, and then sum consumption for each timestamp and idx:

from pyspark.sql import functions as F

input = spark.createDataFrame(df)

# Distinct interval end dates per idx; idx is renamed to avoid an ambiguous join column.
end_dates = input.select("idx", "valid_to").distinct().withColumnRenamed("idx", "idx2")

(
    input
    # sequence includes its end point, so subtract one day to exclude valid_to.
    .withColumn("all_days", F.sequence("valid_from", F.date_sub("valid_to", 1)))
    # One row per covered day.
    .withColumn("timestamp", F.explode("all_days"))
    .groupBy("idx", "timestamp").sum("consumption")
    .withColumnRenamed("sum(consumption)", "consumption")
    # The full outer join brings back end dates that no interval covers.
    .join(end_dates,
          (F.col("timestamp") == F.col("valid_to")) & (F.col("idx") == F.col("idx2")),
          "full_outer")
    .withColumn("idx", F.coalesce("idx", "idx2"))
    .withColumn("timestamp", F.coalesce("timestamp", "valid_to"))
    .drop("idx2", "valid_to")
    # Restored end dates have no consumption, so fill the nulls with 0.0.
    .fillna(0.0)
    .orderBy("idx", "timestamp")
    .show()
)

Output:

+---+-------------------+-----------+
|idx|          timestamp|consumption|
+---+-------------------+-----------+
|  0|2019-01-01 00:00:00|        5.0|
|  0|2019-01-02 00:00:00|        0.0|
|  1|2019-01-01 00:00:00|       10.0|
|  1|2019-01-02 00:00:00|       20.0|
|  1|2019-01-03 00:00:00|       50.0|
|  1|2019-01-04 00:00:00|       50.0|
|  1|2019-01-05 00:00:00|        0.0|
|  1|2019-01-06 00:00:00|       40.0|
|  1|2019-01-07 00:00:00|       40.0|
|  1|2019-01-08 00:00:00|        0.0|
+---+-------------------+-----------+

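Remarks:

  • sequence includes the last value of the interval, so one day has to be subtracted from valid_to.
  • The missing interval end dates are then restored with a full join against the original valid_to values, and the resulting null values are filled with 0.0.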
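
As a cross-check on the logic described in the remarks, here is a minimal pure-Python sketch of the same expand-and-sum idea, using the sample data from the question. It is not part of the Spark solution: the helper name sum_per_day and the ROWS list are made up for illustration, and plain dates stand in for the UTC timestamps.

```python
from collections import defaultdict
from datetime import date, timedelta

# Same sample data as the question: (idx, consumption, valid_from, valid_to).
ROWS = [
    (1, 10.0, date(2019, 1, 1), date(2019, 1, 2)),
    (1, 20.0, date(2019, 1, 2), date(2019, 1, 5)),
    (1, 30.0, date(2019, 1, 3), date(2019, 1, 5)),
    (1, 40.0, date(2019, 1, 6), date(2019, 1, 8)),
    (0, 5.0, date(2019, 1, 1), date(2019, 1, 2)),
]


def sum_per_day(rows):
    """Hypothetical helper: sum consumption per (idx, day) over [valid_from, valid_to)."""
    totals = defaultdict(float)
    for idx, consumption, valid_from, valid_to in rows:
        # Touching the end date creates it with 0.0 when no interval covers it
        # (the role of the full outer join plus fillna in the Spark version).
        totals[(idx, valid_to)] += 0.0
        day = valid_from
        while day < valid_to:  # valid_to itself is excluded
            totals[(idx, day)] += consumption
            day += timedelta(days=1)
    return sorted(totals.items())


for (idx, day), total in sum_per_day(ROWS):
    print(idx, day.isoformat(), total)
```

An end date that is also covered by another interval (2019-01-02 for idx 1) keeps that interval's sum, while an uncovered end date (2019-01-05 for idx 1) shows up with 0.0, which is the behavior the full join is there to produce.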