How to create N duplicated rows in PySpark DataFrame?

I have the following PySpark DataFrame df:

itemid  eventid    timestamp     timestamp_end   n
134     30         2016-07-02    2016-07-09      2
134     32         2016-07-03    2016-07-10      2
125     32         2016-07-10    2016-07-17      1

I want to convert this DataFrame into the following one:

itemid  eventid    timestamp_start   timestamp     timestamp_end
134     30         2016-07-02        2016-07-02    2016-07-09
134     32         2016-07-02        2016-07-03    2016-07-09
134     30         2016-07-03        2016-07-02    2016-07-10
134     32         2016-07-03        2016-07-03    2016-07-10
125     32         2016-07-10        2016-07-10    2016-07-17

Basically, for each unique value of itemid, I need to take timestamp and put it into a new column timestamp_start. So each row in an itemid group should be duplicated n times, where n is the number of records in the group. I hope I explained that clearly.

Here is my initial DataFrame in PySpark:

from pyspark.sql.functions import col, expr

df = (
    sc.parallelize([
        (134, 30, "2016-07-02", "2016-07-09", 2), (134, 32, "2016-07-03", "2016-07-10", 2),
        (125, 32, "2016-07-10", "2016-07-17", 1),
    ]).toDF(["itemid", "eventid", "timestamp", "timestamp_end", "n"])
    .withColumn("timestamp", col("timestamp").cast("timestamp"))
    .withColumn("timestamp_end", col("timestamp_end").cast("timestamp"))
)
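
Since n is just the record count per itemid group, it could also be derived on the fly with a Window count instead of being stored in the data. A minimal sketch, assuming the df defined above (the column name n_derived is only illustrative):

from pyspark.sql.functions import count
from pyspark.sql import Window

# n_derived = number of rows sharing the same itemid (itemid is never null here)
df2 = df.withColumn("n_derived", count("itemid").over(Window.partitionBy("itemid")))
df2.select("itemid", "eventid", "n", "n_derived").show()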

So far, I have managed to duplicate the rows n times:

new_df = df.withColumn("n", expr("explode(array_repeat(n,int(n)))"))

But how do I create the timestamp_start column as shown in the example above?

Thank you.

IIUC, you can use the Window function collect_list to gather the list of all timestamp + timestamp_end pairs within each group, and then use the Spark SQL built-in function inline/inline_outer to explode the resulting array of structs:

from pyspark.sql.functions import collect_list, expr
from pyspark.sql import Window

w1 = Window.partitionBy('itemid')

df.withColumn('timestamp_range',
    collect_list(expr("(timestamp as timestamp_start, timestamp_end)")).over(w1)
).selectExpr(
    'itemid',
    'eventid',
    'timestamp',
    'inline_outer(timestamp_range)'
).show()
+------+-------+----------+---------------+-------------+
|itemid|eventid| timestamp|timestamp_start|timestamp_end|
+------+-------+----------+---------------+-------------+
|   134|     30|2016-07-02|     2016-07-02|   2016-07-09|
|   134|     30|2016-07-02|     2016-07-03|   2016-07-10|
|   134|     32|2016-07-03|     2016-07-02|   2016-07-09|
|   134|     32|2016-07-03|     2016-07-03|   2016-07-10|
|   125|     32|2016-07-10|     2016-07-10|   2016-07-17|
+------+-------+----------+---------------+-------------+

Here, timestamp_range is the collect_list of the following named_struct (in Spark SQL syntax):

(timestamp as timestamp_start, timestamp_end)

which is the same as:

named_struct('timestamp_start', timestamp, 'timestamp_end', timestamp_end)
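
Since the two forms are equivalent, the same query could also be written with the explicit named_struct; a minimal sketch assuming the df above:

from pyspark.sql.functions import collect_list, expr
from pyspark.sql import Window

w1 = Window.partitionBy('itemid')

# explicit named_struct instead of the shorthand (timestamp as timestamp_start, timestamp_end)
df.withColumn('timestamp_range',
    collect_list(expr("named_struct('timestamp_start', timestamp, 'timestamp_end', timestamp_end)")).over(w1)
).selectExpr(
    'itemid',
    'eventid',
    'timestamp',
    'inline_outer(timestamp_range)'
).show()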