Reshape pyspark dataframe to show moving window of item interactions

I have a large pyspark dataframe of subject-item interactions in long format: each row describes a subject interacting with some item of interest, along with a timestamp and the rank order of that interaction for the subject (i.e., the first interaction is 1, the second is 2, and so on). Here are a few rows:

+----------+---------+----------------------+--------------------+
|      date|itemId   |interaction_date_order|              userId|
+----------+---------+----------------------+--------------------+
|2019-07-23| 10005880|                     1|37                  |
|2019-07-23| 10005903|                     2|37                  |
|2019-07-23| 10005903|                     3|37                  |
|2019-07-23| 12458442|                     4|37                  |
|2019-07-26| 10005903|                     5|37                  |
|2019-07-26| 12632813|                     6|37                  |
|2019-07-26| 12632813|                     7|37                  |
|2019-07-26| 12634497|                     8|37                  |
|2018-11-24| 12245677|                     1|5                   |
|2018-11-24| 12245677|                     2|5                   |
|2019-07-29| 12541871|                     3|5                   |
|2019-07-29| 12541871|                     4|5                   |
|2019-07-30| 12626854|                     5|5                   |
|2019-08-31| 12776880|                     6|5                   |
|2019-08-31| 12776880|                     7|5                   |
+----------+---------+----------------------+--------------------+

I need to reshape these data so that, for each subject, a row holds a window of 5 interactions. Something like this:

+------+--------+--------+--------+--------+--------+
|userId| i-2    |  i-1   |   i    |    i+1 |     i+2|
+------+--------+--------+--------+--------+--------+
|37    |10005880|10005903|10005903|12458442|10005903|
|37    |10005903|10005903|12458442|10005903|12632813|
+------+--------+--------+--------+--------+--------+

Does anyone have suggestions for how I might do this?

Import spark and everything else

from pyspark.sql import *
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession

# A local context/session pair is enough for this example
sc = SparkContext('local')
spark = SparkSession(sc)
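
On newer PySpark versions the builder API is the usual entry point; a minimal equivalent for local experimentation (the appName is an arbitrary label):

# Equivalent local session via the builder API
spark = SparkSession.builder.master('local').appName('interactions').getOrCreate()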

Create your dataframe

# Raw sample rows, pipe-delimited; the column names are supplied via Row below
lines = '''2019-07-23| 10005880|                     1|37                  |
2019-07-23| 10005903|                     2|37                  |
2019-07-23| 10005903|                     3|37                  |
2019-07-23| 12458442|                     4|37                  |
2019-07-26| 10005903|                     5|37                  |
2019-07-26| 12632813|                     6|37                  |
2019-07-26| 12632813|                     7|37                  |
2019-07-26| 12634497|                     8|37                  |
2018-11-24| 12245677|                     1|5                   |
2018-11-24| 12245677|                     2|5                   |
2019-07-29| 12541871|                     3|5                   |
2019-07-29| 12541871|                     4|5                   |
2019-07-30| 12626854|                     5|5                   |
2019-08-31| 12776880|                     6|5                   |
2019-08-31| 12776880|                     7|5                   |'''

Interaction = Row("date", "itemId", "interaction_date_order", "userId")
interactions = []
for line in lines.split('\n'):
    # The trailing '|' leaves an empty last field, so only the first four
    # fields are used; int() tolerates the padding spaces
    column_values = line.split('|')
    interaction = Interaction(column_values[0], int(column_values[1]),
                              int(column_values[2]), int(column_values[3]))
    interactions.append(interaction)

df = spark.createDataFrame(interactions)
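
As an aside: if you build the rows as plain tuples, createDataFrame also accepts a DDL schema string (Spark 2.3+), which saves the Row factory. A sketch with just the first two rows:

# Same dataframe built from tuples plus a DDL schema string
df_sketch = spark.createDataFrame(
    [("2019-07-23", 10005880, 1, 37),
     ("2019-07-23", 10005903, 2, 37)],
    schema="date string, itemId long, interaction_date_order int, userId int",
)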

Now we have

df.show()

+----------+--------+----------------------+------+
|      date|  itemId|interaction_date_order|userId|
+----------+--------+----------------------+------+
|2019-07-23|10005880|                     1|    37|
|2019-07-23|10005903|                     2|    37|
|2019-07-23|10005903|                     3|    37|
|2019-07-23|12458442|                     4|    37|
|2019-07-26|10005903|                     5|    37|
|2019-07-26|12632813|                     6|    37|
|2019-07-26|12632813|                     7|    37|
|2019-07-26|12634497|                     8|    37|
|2018-11-24|12245677|                     1|     5|
|2018-11-24|12245677|                     2|     5|
|2019-07-29|12541871|                     3|     5|
|2019-07-29|12541871|                     4|     5|
|2019-07-30|12626854|                     5|     5|
|2019-08-31|12776880|                     6|     5|
|2019-08-31|12776880|                     7|     5|
+----------+--------+----------------------+------+

Create a window, collect the itemIds over it, and count them

from pyspark.sql.window import Window
import pyspark.sql.functions as F

# Per-user window over the current interaction and the next four,
# ordered by interaction rank
window = (Window
          .partitionBy('userId')
          .orderBy('interaction_date_order')
          .rowsBetween(Window.currentRow, Window.currentRow + 4))

df2 = df.withColumn("itemId_list", F.collect_list('itemId').over(window))
df2 = df2.withColumn("itemId_count", F.count('itemId').over(window))
# Keep only full windows; the last rows of each user's history collect
# fewer than 5 items and are dropped here
df_final = df2.where(df2['itemId_count'] == 5)
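
An equivalent filter, if the second aggregate over the window bothers you: the length of the collected list already says whether the window is full, so F.size can replace the separate count. A sketch against the same df and window:

# Same filter using the array length instead of a second window aggregate
df2_alt = df.withColumn("itemId_list", F.collect_list('itemId').over(window))
df_final_alt = df2_alt.where(F.size('itemId_list') == 5)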

Now we have

df_final.show()
+----------+--------+----------------------+------+--------------------+------------+
|      date|  itemId|interaction_date_order|userId|         itemId_list|itemId_count|
+----------+--------+----------------------+------+--------------------+------------+
|2018-11-24|12245677|                     1|     5|[12245677, 122456...|           5|
|2018-11-24|12245677|                     2|     5|[12245677, 125418...|           5|
|2019-07-29|12541871|                     3|     5|[12541871, 125418...|           5|
|2019-07-23|10005880|                     1|    37|[10005880, 100059...|           5|
|2019-07-23|10005903|                     2|    37|[10005903, 100059...|           5|
|2019-07-23|10005903|                     3|    37|[10005903, 124584...|           5|
|2019-07-23|12458442|                     4|    37|[12458442, 100059...|           5|
+----------+--------+----------------------+------+--------------------+------------+

The finishing touch

df_final2 = (df_final
             # Unpack the 5-element list into the named offset columns
             .withColumn('i-2', df_final['itemId_list'][0])
             .withColumn('i-1', df_final['itemId_list'][1])
             .withColumn('i', df_final['itemId_list'][2])
             .withColumn('i+1', df_final['itemId_list'][3])
             .withColumn('i+2', df_final['itemId_list'][4])
             .select('userId', 'i-2', 'i-1', 'i', 'i+1', 'i+2')
            )
df_final2.show()
+------+--------+--------+--------+--------+--------+                           
|userId|     i-2|     i-1|       i|     i+1|     i+2|
+------+--------+--------+--------+--------+--------+
|     5|12245677|12245677|12541871|12541871|12626854|
|     5|12245677|12541871|12541871|12626854|12776880|
|     5|12541871|12541871|12626854|12776880|12776880|
|    37|10005880|10005903|10005903|12458442|10005903|
|    37|10005903|10005903|12458442|10005903|12632813|
|    37|10005903|12458442|10005903|12632813|12632813|
|    37|12458442|10005903|12632813|12632813|12634497|
+------+--------+--------+--------+--------+--------+
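
If the window length ever changes, the hardcoded indices above get tedious. One way to generate the offset columns in a loop; WINDOW_SIZE is a hypothetical parameter here, and the rowsBetween bound above has to be kept in sync with it:

# Generate 'i-2', 'i-1', 'i', 'i+1', 'i+2' programmatically
WINDOW_SIZE = 5
half = WINDOW_SIZE // 2

def offset_name(offset):
    # offsets -2..2 map to 'i-2', 'i-1', 'i', 'i+1', 'i+2'
    return 'i' if offset == 0 else f'i{offset:+d}'

df_param = df_final.select(
    'userId',
    *[df_final['itemId_list'][k].alias(offset_name(k - half))
      for k in range(WINDOW_SIZE)]
)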