PySpark Multiple Columns Using Windows

I have a DataFrame as follows:

+---+----------+-----------+
|id |date_1    |date_2     |
+---+----------+-----------+
|0  |2017-01-21|2017-04-01 |
|1  |2017-01-22|2017-04-24 |
|2  |2017-02-23|2017-04-30 |
|3  |2017-02-27|2017-04-30 |
|4  |2017-04-23|2017-05-27 |
|5  |2017-04-29|2017-06-30 |
|6  |2017-06-13|2017-07-05 |
|7  |2017-06-13|2017-07-18 |
|8  |2017-06-16|2017-07-19 |
|9  |2017-07-09|2017-08-02 |
|10 |2017-07-18|2017-08-07 |
|11 |2017-07-28|2017-08-11 |
|12 |2017-07-28|2017-08-13 |
|13 |2017-08-04|2017-08-13 |
|14 |2017-08-13|2017-08-13 |
|15 |2017-08-13|2017-08-13 |
|16 |2017-08-13|2017-08-25 |
|17 |2017-08-13|2017-09-10 |
|18 |2017-08-31|2017-09-21 |
|19 |2017-10-03|2017-09-22 |
+---+----------+-----------+

I know there are plenty of ways to do what I'm asking with various pyspark APIs, however I would like to use the Window API to accomplish the following.

In any other circumstance it would essentially be a double for loop.

For each date in date_1, look at every date in date_2 that lies in the same or a subsequent row, and count how many of the differences fall within a week, a month, ... (the time frame is irrelevant, but for consistency let's say a week). Then use these results to add another column holding the count.
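
As a sketch of what I mean, in plain Python that double loop might look roughly like this (rows here is a hypothetical list of (id, date_1, date_2) tuples already sorted by id, using a one-week range):

import datetime

def count_within_week(rows):
    # For each row X, count the rows Y at or after X whose date_2 falls
    # within 0..7 days after X.date_1.
    results = []
    for i, (_, date_1, _) in enumerate(rows):
        count = 0
        for _, _, date_2 in rows[i:]:
            if datetime.timedelta(0) <= date_2 - date_1 <= datetime.timedelta(days=7):
                count += 1
        results.append(count)
    return results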

The challenge is getting the right combination of Window(s) to take both date columns into account.

If I understand the question author correctly: for each row X in the DataFrame, we want to go over all rows starting from that one (ordered by, e.g., id) and for each such row Y compare X.date_1 with Y.date_2. The number of rows Y for which the difference between X.date_1 and Y.date_2 is less than, e.g., 1 week should then be added as a column to row X (e.g., X.result).

Unfortunately, windowing functions provide no way to access X.date_1 inside a window function, so this cannot be achieved with windowing functions alone.

This seems very similar to a question where the author tries to do a similar thing for Postgres.

Still, there is a way to actually do it with a bit of cheating, namely by "materializing" the window frame for each row into an array and then performing the required operation on it. Not sure whether that counts as valid in your view, but it is the only way the Window API can be used to solve the problem. A possible solution could look like this (assuming we want to count the number of rows Y not earlier than X with respect to id where Y.date_2 is between X.date_1 and X.date_1 + 7 days):

import datetime

# Parse the raw table text into lists of strings: [id, date_1, date_2]
rawdata = [l.strip('|').replace('|', ' ').split() for l in '''|0  |2017-01-21|2017-04-01 |
|1  |2017-01-22|2017-04-24 |
|2  |2017-02-23|2017-04-30 |
|3  |2017-02-27|2017-04-30 |
|4  |2017-04-23|2017-05-27 |
|5  |2017-04-29|2017-06-30 |
|6  |2017-06-13|2017-07-05 |
|7  |2017-06-13|2017-07-18 |
|8  |2017-06-16|2017-07-19 |
|9  |2017-07-09|2017-08-02 |
|10 |2017-07-18|2017-08-07 |
|11 |2017-07-28|2017-08-11 |
|12 |2017-07-28|2017-08-13 |
|13 |2017-08-04|2017-08-13 |
|14 |2017-08-13|2017-08-13 |
|15 |2017-08-13|2017-08-13 |
|16 |2017-08-13|2017-08-25 |
|17 |2017-08-13|2017-09-10 |
|18 |2017-08-31|2017-09-21 |
|19 |2017-10-03|2017-09-22 |'''.split('\n')]
# Convert each row to (int id, date, date) and build the DataFrame
data = [(int(d[0]), datetime.date.fromisoformat(d[1]), datetime.date.fromisoformat(d[2])) for d in rawdata]
df = spark.createDataFrame(data, schema='id: bigint, date_1: Date, date_2: Date')

from pyspark.sql.window import Window
import pyspark.sql.functions as func

# Frame containing the current row and all following rows, ordered by id
window_spec = Window.orderBy('id').rowsBetween(Window.currentRow, Window.unboundedFollowing)

# "Materialize" each row's frame of date_2 values into an array, then count the
# elements whose difference from date_1 is between 0 and 7 days
new_df = df.withColumn('materialized_frame_date_2', func.collect_list(df['date_2']).over(window_spec)) \
  .withColumn('result', func.expr('size(filter(materialized_frame_date_2, x -> datediff(x, date_1) BETWEEN 0 AND 7))')) \
  .drop('materialized_frame_date_2')
new_df.show()

Result:

+---+----------+----------+------+
| id|    date_1|    date_2|result|
+---+----------+----------+------+
|  0|2017-01-21|2017-04-01|     0|
|  1|2017-01-22|2017-04-24|     0|
|  2|2017-02-23|2017-04-30|     0|
|  3|2017-02-27|2017-04-30|     0|
|  4|2017-04-23|2017-05-27|     0|
|  5|2017-04-29|2017-06-30|     0|
|  6|2017-06-13|2017-07-05|     0|
|  7|2017-06-13|2017-07-18|     0|
|  8|2017-06-16|2017-07-19|     0|
|  9|2017-07-09|2017-08-02|     0|
| 10|2017-07-18|2017-08-07|     0|
| 11|2017-07-28|2017-08-11|     0|
| 12|2017-07-28|2017-08-13|     0|
| 13|2017-08-04|2017-08-13|     0|
| 14|2017-08-13|2017-08-13|     2|
| 15|2017-08-13|2017-08-13|     1|
| 16|2017-08-13|2017-08-25|     0|
| 17|2017-08-13|2017-09-10|     0|
| 18|2017-08-31|2017-09-21|     0|
| 19|2017-10-03|2017-09-22|     0|
+---+----------+----------+------+
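
For example, row 14 gets a result of 2 because the date_2 values of rows 14 and 15 (both 2017-08-13) lie within 7 days of its date_1 (2017-08-13), while row 16's date_2 (2017-08-25) does not; for row 15, only its own date_2 still qualifies, giving 1.

The same trick generalizes to the other time ranges mentioned in the question by changing the upper bound of the filter predicate. A sketch for a one-month range, reusing window_spec from above and assuming a month is approximated as 30 days:

# a one-month range, approximated here as 30 days (assumption)
new_df_month = df.withColumn('materialized_frame_date_2', func.collect_list(df['date_2']).over(window_spec)) \
  .withColumn('result', func.expr('size(filter(materialized_frame_date_2, x -> datediff(x, date_1) BETWEEN 0 AND 30))')) \
  .drop('materialized_frame_date_2')
new_df_month.show()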

Maybe this helps -

Load the provided test data

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._
    import spark.implicits._  // needed for toDS() and the $"..." column syntax

    val data =
      """
        |id |date_1    |date_2
        |0  |2017-01-21|2017-04-01
        |1  |2017-01-22|2017-04-24
        |2  |2017-02-23|2017-04-30
        |3  |2017-02-27|2017-04-30
        |4  |2017-04-23|2017-05-27
        |5  |2017-04-29|2017-06-30
        |6  |2017-06-13|2017-07-05
        |7  |2017-06-13|2017-07-18
        |8  |2017-06-16|2017-07-19
        |9  |2017-07-09|2017-08-02
        |10 |2017-07-18|2017-08-07
        |11 |2017-07-28|2017-08-11
        |12 |2017-07-28|2017-08-13
        |13 |2017-08-04|2017-08-13
        |14 |2017-08-13|2017-08-13
        |15 |2017-08-13|2017-08-13
        |16 |2017-08-13|2017-08-25
        |17 |2017-08-13|2017-09-10
        |18 |2017-08-31|2017-09-21
        |19 |2017-10-03|2017-09-22
      """.stripMargin

    val stringDS = data.split(System.lineSeparator())
      .map(_.split("\|").map(_.replaceAll("""^[ \t]+|[ \t]+$""", "")).mkString(","))
      .toSeq.toDS()
    val df = spark.read
      .option("sep", ",")
      .option("inferSchema", "true")
      .option("header", "true")
      .option("nullValue", "null")
      .csv(stringDS)

    df.show(false)
    df.printSchema()
    /**
      * +---+-------------------+-------------------+
      * |id |date_1             |date_2             |
      * +---+-------------------+-------------------+
      * |0  |2017-01-21 00:00:00|2017-04-01 00:00:00|
      * |1  |2017-01-22 00:00:00|2017-04-24 00:00:00|
      * |2  |2017-02-23 00:00:00|2017-04-30 00:00:00|
      * |3  |2017-02-27 00:00:00|2017-04-30 00:00:00|
      * |4  |2017-04-23 00:00:00|2017-05-27 00:00:00|
      * |5  |2017-04-29 00:00:00|2017-06-30 00:00:00|
      * |6  |2017-06-13 00:00:00|2017-07-05 00:00:00|
      * |7  |2017-06-13 00:00:00|2017-07-18 00:00:00|
      * |8  |2017-06-16 00:00:00|2017-07-19 00:00:00|
      * |9  |2017-07-09 00:00:00|2017-08-02 00:00:00|
      * |10 |2017-07-18 00:00:00|2017-08-07 00:00:00|
      * |11 |2017-07-28 00:00:00|2017-08-11 00:00:00|
      * |12 |2017-07-28 00:00:00|2017-08-13 00:00:00|
      * |13 |2017-08-04 00:00:00|2017-08-13 00:00:00|
      * |14 |2017-08-13 00:00:00|2017-08-13 00:00:00|
      * |15 |2017-08-13 00:00:00|2017-08-13 00:00:00|
      * |16 |2017-08-13 00:00:00|2017-08-25 00:00:00|
      * |17 |2017-08-13 00:00:00|2017-09-10 00:00:00|
      * |18 |2017-08-31 00:00:00|2017-09-21 00:00:00|
      * |19 |2017-10-03 00:00:00|2017-09-22 00:00:00|
      * +---+-------------------+-------------------+
      *
      * root
      * |-- id: integer (nullable = true)
      * |-- date_1: timestamp (nullable = true)
      * |-- date_2: timestamp (nullable = true)
      */

Count the number of occurrences where the difference (date_1 - date_2) falls within a week

    // week
    val weekDiff = 7
    val w = Window.orderBy("id", "date_1", "date_2")
      .rangeBetween(Window.currentRow, Window.unboundedFollowing)
    df.withColumn("count", sum(
      when(datediff($"date_1", $"date_2") <= weekDiff, 1).otherwise(0)
    ).over(w))
      .orderBy("id")
      .show(false)

    /**
      * +---+-------------------+-------------------+-----+
      * |id |date_1             |date_2             |count|
      * +---+-------------------+-------------------+-----+
      * |0  |2017-01-21 00:00:00|2017-04-01 00:00:00|19   |
      * |1  |2017-01-22 00:00:00|2017-04-24 00:00:00|18   |
      * |2  |2017-02-23 00:00:00|2017-04-30 00:00:00|17   |
      * |3  |2017-02-27 00:00:00|2017-04-30 00:00:00|16   |
      * |4  |2017-04-23 00:00:00|2017-05-27 00:00:00|15   |
      * |5  |2017-04-29 00:00:00|2017-06-30 00:00:00|14   |
      * |6  |2017-06-13 00:00:00|2017-07-05 00:00:00|13   |
      * |7  |2017-06-13 00:00:00|2017-07-18 00:00:00|12   |
      * |8  |2017-06-16 00:00:00|2017-07-19 00:00:00|11   |
      * |9  |2017-07-09 00:00:00|2017-08-02 00:00:00|10   |
      * |10 |2017-07-18 00:00:00|2017-08-07 00:00:00|9    |
      * |11 |2017-07-28 00:00:00|2017-08-11 00:00:00|8    |
      * |12 |2017-07-28 00:00:00|2017-08-13 00:00:00|7    |
      * |13 |2017-08-04 00:00:00|2017-08-13 00:00:00|6    |
      * |14 |2017-08-13 00:00:00|2017-08-13 00:00:00|5    |
      * |15 |2017-08-13 00:00:00|2017-08-13 00:00:00|4    |
      * |16 |2017-08-13 00:00:00|2017-08-25 00:00:00|3    |
      * |17 |2017-08-13 00:00:00|2017-09-10 00:00:00|2    |
      * |18 |2017-08-31 00:00:00|2017-09-21 00:00:00|1    |
      * |19 |2017-10-03 00:00:00|2017-09-22 00:00:00|0    |
      * +---+-------------------+-------------------+-----+
      */
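
For completeness, the same window expression can also be written in PySpark. This is only a sketch, assuming a DataFrame df with id, date_1 and date_2 columns as in the question:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

week_diff = 7
# Frame from the current row to the end of the partition, ordered by id, date_1, date_2
w = Window.orderBy('id', 'date_1', 'date_2') \
    .rangeBetween(Window.currentRow, Window.unboundedFollowing)

df.withColumn('count',
              F.sum(F.when(F.datediff('date_1', 'date_2') <= week_diff, 1).otherwise(0)).over(w)) \
  .orderBy('id') \
  .show(truncate=False)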