Combine date ranges in Spark dataframe

I ran into a similar problem.

However, I'm working with a huge dataset, and I'm trying to see whether I can do the same thing in PySpark instead of pandas. Below is the solution in pandas. Can this be done in PySpark?

def merge_dates(grp):
    # Find contiguous date groups, and get the first/last start/end date for each group.
    dt_groups = (grp['StartDate'] != grp['EndDate'].shift()).cumsum()
    return grp.groupby(dt_groups).agg({'StartDate': 'first', 'EndDate': 'last'})

# Perform a groupby and apply the merge_dates function, followed by formatting.
df = df.groupby(['FruitID', 'FruitType']).apply(merge_dates)
df = df.reset_index().drop('level_2', axis=1) 

We can use the Window and lag functions to compute the contiguous groups, and then aggregate them in a way similar to the pandas function you shared. A working example is given below, hope it helps!

import pandas as pd
from dateutil.parser import parse
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
import pyspark.sql.functions as F


# EXAMPLE DATA -----------------------------------------------

# pd.DataFrame.from_items was removed in pandas 1.0; build the frame from a dict instead.
pdf = pd.DataFrame({'FruitID': [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4],
                    'FruitType': ['Apple', 'Apple', 'Apple', 'Orange', 'Orange', 'Orange', 'Banana', 'Banana', 'Blueberry', 'Mango', 'Kiwi', 'Mango'],
                    'StartDate': [parse(x) for x in ['2015-01-01', '2016-01-01', '2017-01-01', '2015-01-01', '2016-05-31',
                                                     '2017-01-01', '2015-01-01', '2016-01-01', '2017-01-01', '2015-01-01', '2016-09-15', '2017-01-01']],
                    'EndDate': [parse(x) for x in ['2016-01-01', '2017-01-01', '2018-01-01', '2016-01-01', '2017-01-01',
                                                   '2018-01-01', '2016-01-01', '2017-01-01', '2018-01-01', '2016-01-01', '2017-01-01', '2018-01-01']]})

pdf = pdf.sort_values(['FruitID', 'StartDate'])
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(pdf)


# FIND CONTIGUOUS GROUPS AND AGGREGATE ---------------------

w = Window.partitionBy("FruitType").orderBy("StartDate")
contiguous = F.when(F.datediff(F.lag("EndDate", 1).over(w),F.col("StartDate"))!=0,F.lit(1)).otherwise(F.lit(0))
df = (df
      .withColumn('contiguous_grp', F.sum(contiguous).over(w))
      .groupBy('FruitType','contiguous_grp')
      .agg(F.first('StartDate').alias('StartDate'),F.last('EndDate').alias('EndDate'))
      .drop('contiguous_grp'))
df.show()

Output:

+---------+-------------------+-------------------+
|FruitType|          StartDate|            EndDate|
+---------+-------------------+-------------------+
|   Orange|2015-01-01 00:00:00|2016-01-01 00:00:00|
|   Orange|2016-05-31 00:00:00|2018-01-01 00:00:00|
|   Banana|2015-01-01 00:00:00|2017-01-01 00:00:00|
|     Kiwi|2016-09-15 00:00:00|2017-01-01 00:00:00|
|    Mango|2015-01-01 00:00:00|2016-01-01 00:00:00|
|    Mango|2017-01-01 00:00:00|2018-01-01 00:00:00|
|    Apple|2015-01-01 00:00:00|2018-01-01 00:00:00|
|Blueberry|2017-01-01 00:00:00|2018-01-01 00:00:00|
+---------+-------------------+-------------------+
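
If you would rather see plain dates than the midnight timestamps above, the two columns can be cast with to_date before calling show(). This is an optional tweak on top of the snippet above, not part of the original answer:

# Optional: cast the timestamp columns to DateType so that show() prints
# '2015-01-01' instead of '2015-01-01 00:00:00'.
df = df.withColumn('StartDate', F.to_date('StartDate')) \
       .withColumn('EndDate', F.to_date('EndDate'))
df.show()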

When merging date ranges, there are a few things to take into account:

  • ranges that lie entirely inside other ranges
  • null values
  • the acceptable gap size between date ranges (do you also want "touching" date ranges to be combined?)

All of the scripts below handle ranges that lie inside other ranges.
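
All four options share the same two-step trick: use lag over a window to flag rows that start a new group, then turn those flags into a group id with a running cumulative sum. Below is a minimal sketch of just that intermediate step, using the df0 example dataframe defined further down (new_group_flag is an illustrative name, and the `< -1` gap condition is the one used by Options 1 and 3):

from pyspark.sql import functions as F, Window

# A row starts a new group when the previous range ends more than one day
# before the current range starts; the running sum of flags is the group id.
w = Window.partitionBy("id").orderBy("start_date")
new_group_flag = F.when(
    F.datediff(F.lag("end_date", 1).over(w), "start_date") < -1, 1
).otherwise(0)

df0.withColumn("contiguous_grp", F.sum(new_group_flag).over(w)) \
   .orderBy("id", "start_date") \
   .show()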


Option 1

  • works when the start_date and end_date columns don't contain null values
  • works when the start_date and end_date columns contain null values*
  • combines overlapping date ranges
  • combines "touching" (consecutive) date ranges

*A null in the start_date column is treated as the earliest possible date; a null in the end_date column is treated as the latest possible date.

w = Window.partitionBy("id").orderBy("start_date")
contiguous = F.when(F.datediff(F.lag("end_date", 1).over(w), "start_date") < -1, 1).otherwise(0)
df = (df0
      .withColumn("contiguous_grp", F.sum(contiguous).over(w))
      .groupBy("id", "contiguous_grp")
      .agg(
          F.first("start_date").alias("start_date"),
          F.when(F.expr("any(end_date is null)"), None).otherwise(F.max("end_date")).alias("end_date"))
      .drop("contiguous_grp"))

Example dataframe:

from pyspark.sql import SparkSession, functions as F, Window
spark = SparkSession.builder.getOrCreate()

data = [("separate",           "2022-01-01", "2022-01-09"),
        ("separate",           "2022-01-11", "2022-01-20"),
        ("consecutive",        "2022-02-01", "2022-02-10"),
        ("consecutive",        "2022-02-11", "2022-02-20"),
        ("overlapping by 1",   "2022-03-01", "2022-03-11"),
        ("overlapping by 1",   "2022-03-11", "2022-03-20"),
        ("overlapping by 2",   "2022-04-01", "2022-04-12"),
        ("overlapping by 2",   "2022-04-11", "2022-04-20"),
        ("inside",             "2022-05-01", "2022-05-20"),
        ("inside",             "2022-05-02", "2022-05-19"),
        ("common_start",       "2022-06-01", "2022-06-20"),
        ("common_start",       "2022-06-01", "2022-06-19"),
        ("common_end",         "2022-07-01", "2022-07-20"),
        ("common_end",         "2022-07-02", "2022-07-20"),
        ("N separate",                 None, "2022-01-09"),
        ("N separate",         "2022-01-11",         None),
        ("N consecutive",              None, "2022-02-10"),
        ("N consecutive",      "2022-02-11",         None),
        ("N overlapping by 1",         None, "2022-03-11"),
        ("N overlapping by 1", "2022-03-11",         None),
        ("N overlapping by 2",         None, "2022-04-12"),
        ("N overlapping by 2", "2022-04-11",         None),
        ("N inside",                   None,         None),
        ("N inside",           "2022-05-02", "2022-05-19"),
        ("N common_start",             None, "2022-06-20"),
        ("N common_start",             None, "2022-06-19"),
        ("N common_end",       "2022-07-01",         None),
        ("N common_end",       "2022-07-02",         None)]
df0 = spark.createDataFrame(data, ["id", "start_date", "end_date"]).select(
    "id",
    F.col("start_date").cast("date"),
    F.col("end_date").cast("date")
)

Result:

+------------------+----------+----------+
|                id|start_date|  end_date|
+------------------+----------+----------+
|      N common_end|2022-07-01|      null|
|    N common_start|      null|2022-06-20|
|     N consecutive|      null|      null|
|          N inside|      null|      null|
|N overlapping by 1|      null|      null|
|N overlapping by 2|      null|      null|
|        N separate|      null|2022-01-09|
|        N separate|2022-01-11|      null|
|        common_end|2022-07-01|2022-07-20|
|      common_start|2022-06-01|2022-06-20|
|       consecutive|2022-02-01|2022-02-20|
|            inside|2022-05-01|2022-05-20|
|  overlapping by 1|2022-03-01|2022-03-20|
|  overlapping by 2|2022-04-01|2022-04-20|
|          separate|2022-01-01|2022-01-09|
|          separate|2022-01-11|2022-01-20|
+------------------+----------+----------+

Option 2

  • works when the start_date and end_date columns don't contain null values
  • works when the start_date and end_date columns contain null values*
  • combines overlapping date ranges
  • does not combine "touching" (consecutive) date ranges

*A null in the start_date column is treated as the earliest possible date; a null in the end_date column is treated as the latest possible date.

w = Window.partitionBy("id").orderBy("start_date")
contiguous = F.when(F.datediff(F.lag("end_date", 1).over(w), "start_date") < 0, 1).otherwise(0)
df = (df0
      .withColumn("contiguous_grp", F.sum(contiguous).over(w))
      .groupBy("id", "contiguous_grp")
      .agg(
          F.first("start_date").alias("start_date"),
          F.when(F.expr("any(end_date is null)"), None).otherwise(F.max("end_date")).alias("end_date"))
      .drop("contiguous_grp"))

Option 3

  • works when the start_date and end_date columns don't contain null values
  • does not fully work when the start_date and end_date columns contain null values*
  • combines overlapping date ranges
  • combines "touching" (consecutive) date ranges

*Null values, if present, are simply ignored.

w = Window.partitionBy("id").orderBy("start_date")
contiguous = F.when(F.datediff(F.lag("end_date", 1).over(w), "start_date") < -1, 1).otherwise(0)
df = (df0
      .withColumn("contiguous_grp", F.sum(contiguous).over(w))
      .groupBy("id", "contiguous_grp")
      .agg(F.min("start_date").alias("start_date"), F.max("end_date").alias("end_date"))
      .drop("contiguous_grp"))

Option 4

  • works when the start_date and end_date columns don't contain null values
  • does not fully work when the start_date and end_date columns contain null values*
  • combines overlapping date ranges
  • does not combine "touching" (consecutive) date ranges

*Null values, if present, are simply ignored.

w = Window.partitionBy("id").orderBy("start_date")
contiguous = F.when(F.datediff(F.lag("end_date", 1).over(w), "start_date") < 0, 1).otherwise(0)
df = (df0
      .withColumn("contiguous_grp", F.sum(contiguous).over(w))
      .groupBy("id", "contiguous_grp")
      .agg(F.min("start_date").alias("start_date"), F.max("end_date").alias("end_date"))
      .drop("contiguous_grp"))