识别与 PySpark 数据框中的当前值不同的最新记录

identify the most recent record different from the current value in a PySpark dataframe

我有一个 PySpark 数据框,其中每个用户在某个时间点都有特定的状态,如下面的虚拟数据

    --------------------------
    |user_id| status| month  |
    --------------------------
    | 1     | A     | 12/2020|
    | 1     | A     | 11/2020|
    | 1     | B     | 10/2020|
    | 1     | B     | 09/2020|
    | 1     | A     | 08/2020|
    | 1     | C     | 07/2020|
    | 2     | A     | 12/2020|
    | 2     | A     | 11/2020|
    | 2     | A     | 10/2020|
    | 2     | B     | 09/2020|

我想在我的 PySpark 数据框中创建另外两列(previous_status_value 和 previous_status_month),其中每条记录指示用户状态不同的最近日期记录中的那个,那个值是多少。使用上述虚拟数据,结果将是

    ------------------------------------------------------------------------
    |user_id| status| month  | previous_status_value| previous_status_month|
    ------------------------------------------------------------------------
    | 1     | A     | 12/2020| B                    | 10/2020              |
    | 1     | A     | 11/2020| B                    | 10/2020              |
    | 1     | B     | 10/2020| A                    | 08/2020              |
    | 1     | B     | 09/2020| A                    | 08/2020              |
    | 1     | A     | 08/2020| C                    | 07/2020              |
    | 1     | C     | 07/2020| Null                 | Null                 |
    | 2     | A     | 12/2020| B                    | 09/2020              |
    | 2     | A     | 11/2020| B                    | 09/2020              |
    | 2     | A     | 10/2020| B                    | 09/2020              |
    | 2     | B     | 09/2020| Null                 | Null                 |

数据框有数百万条记录,所以我试图使用 Window 函数(类似于 )来解决这个问题,但没有成功。

使用lead查找状态变化的地方,只保留状态变化对应的statusmonth,并用null掩码,否则使用when(F.col('begin'), F.col('status')),得到使用 F.last(..., ignorenulls=True).

的先前值
import pyspark.sql.functions as F
from pyspark.sql.window import Window

w = Window.partitionBy('user_id').orderBy('month')
begin = F.lead('status').over(w) != F.col('status')
df = df.select('*', begin.alias('begin'))

w = w.rowsBetween(Window.unboundedPreceding, -1)
previous_status_value = F.last(F.when(F.col('begin'), F.col('status')), ignorenulls=True).over(w).alias('previous_status_value')
previous_status_month = F.last(F.when(F.col('begin'), F.col('month')), ignorenulls=True).over(w).alias('previous_status_month ')

df = df.select('*', previous_status_value, previous_status_month).drop('begin').orderBy('user_id', F.col('month').desc())

df.show()
+-------+------+-------+---------------------+----------------------+
|user_id|status|  month|previous_status_value|previous_status_month |
+-------+------+-------+---------------------+----------------------+
|      1|     A|12/2020|                    B|               10/2020|
|      1|     A|11/2020|                    B|               10/2020|
|      1|     B|10/2020|                    A|               08/2020|
|      1|     B|09/2020|                    A|               08/2020|
|      1|     A|08/2020|                    C|               07/2020|
|      1|     C|07/2020|                 null|                  null|
|      2|     A|12/2020|                    B|               09/2020|
|      2|     A|11/2020|                    B|               09/2020|
|      2|     A|10/2020|                    B|               09/2020|
|      2|     B|09/2020|                 null|                  null|
+-------+------+-------+---------------------+----------------------+