识别在 PySpark DF ArrayType 列上运行的简洁方法

Clean way to identify runs on a PySpark DF ArrayType column

给定以下形式的 PySpark DataFrame:

+----+--------+
|time|messages|
+----+--------+
| t01|    [m1]|
| t03|[m1, m2]|
| t04|    [m2]|
| t06|    [m3]|
| t07|[m3, m1]|
| t08|    [m1]|
| t11|    [m2]|
| t13|[m2, m4]|
| t15|    [m2]|
| t20|    [m4]|
| t21|      []|
| t22|[m1, m4]|
+----+--------+

我想重构它以压缩包含相同消息的运行(输出的顺序并不重要,但为了清楚起见对她进行了排序):

+----------+--------+-------+
|start_time|end_time|message|
+----------+--------+-------+
|       t01|     t03|     m1|
|       t07|     t08|     m1|
|       t22|     t22|     m1|
|       t03|     t04|     m2|
|       t11|     t15|     m2|
|       t06|     t07|     m3|
|       t13|     t13|     m4|
|       t20|     t20|     m4|
|       t22|     t22|     m4|
+----------+--------+-------+

(即将 message 列视为一个序列,并为每条消息标识 "runs" 的开始和结束),

有没有一种干净的方法可以在 Spark 中进行这种转换?目前,我将其作为 6 GB TSV 转储并强制处理。

如果 Pandas 有一个干净的方法来进行此聚合,我愿意接受 toPandas-ing 并在驱动程序上累积的可能性。

(请参阅 了解简单的基线实施)。

找到了一种合理的方法来执行此操作,如果您可以在应用 window 操作时进行分区(您应该能够在任何真实数据集上进行,我能够在我推导出这个的数据集上进行)问题来自).

为了便于解释,将其分成块(导入仅在第一个片段中)。

设置:

# Need these for the setup
import pandas as pd
from pyspark.sql.types import ArrayType, StringType, StructField, StructType

# We'll need these later
from pyspark.sql.functions import array_except, coalesce, col, explode, from_json, lag, lit, rank
from pyspark.sql.window import Window

rows = [
    ['t01',['m1']],
    ['t03',['m1','m2']],
    ['t04',['m2']],
    ['t06',['m3']],
    ['t07',['m3','m1']],
    ['t08',['m1']],
    ['t11',['m2']],
    ['t13',['m2','m4']],
    ['t15',['m2']],
    ['t20',['m4']],
    ['t21',[]],
    ['t22',['m1','m4']],
]

pdf = pd.DataFrame(rows,columns=['time', 'messages'])
schema = StructType([
    StructField("time", StringType(), True),
    StructField("messages", ArrayType(StringType()), True)
])
df = spark.createDataFrame(pdf,schema=schema)

按时间排序,滞后并生成消息数组的差异以标识运行的开始和结束:

w = Window().partitionBy().orderBy('time')
df2 = df.withColumn('messages_lag_1', lag('messages', 1).over(w))\
        .withColumn('end_time', lag('time', 1).over(w))\
        .withColumnRenamed('time', 'start_time')\
        .withColumn('messages_lag_1',          # Replace nulls with []
            coalesce(                          # cargoculted from
                col('messages_lag_1'),         # 
                from_json(lit('[]'), ArrayType(StringType()))
            )
        )\
        .withColumn('message_run_starts', array_except('messages', 'messages_lag_1'))\
        .withColumn('message_run_ends', array_except('messages_lag_1', 'messages'))\
        .drop(*['messages', 'messages_lag_1']) # ^ only on Spark > 2.4

+----------+--------+------------------+----------------+
|start_time|end_time|message_run_starts|message_run_ends|
+----------+--------+------------------+----------------+
|       t01|    null|              [m1]|              []|
|       t03|     t01|              [m2]|              []|
|       t04|     t03|                []|            [m1]|
|       t06|     t04|              [m3]|            [m2]|
|       t07|     t06|              [m1]|              []|
|       t08|     t07|                []|            [m3]|
|       t11|     t08|              [m2]|            [m1]|
|       t13|     t11|              [m4]|              []|
|       t15|     t13|                []|            [m4]|
|       t20|     t15|              [m4]|            [m2]|
|       t21|     t20|                []|            [m4]|
|       t22|     t21|          [m1, m4]|              []|
+----------+--------+------------------+----------------+

按时间和消息分组,并对开始表和结束表应用排名。加入并在空值的情况下,将 start_time 复制到 end_time:

w_start = Window().partitionBy('message_run_starts').orderBy(col('start_time'))
df3 = df2.withColumn('message_run_starts', explode('message_run_starts')).drop('message_run_ends', 'end_time')
df3 = df3.withColumn('start_row_id',rank().over(w_start))

w_end = Window().partitionBy('message_run_ends').orderBy(col('end_time'))
df4 = df2.withColumn('message_run_ends', explode('message_run_ends')).drop('message_run_starts', 'start_time')
df4 = df4.withColumn('end_row_id',rank().over(w_end))

df_combined = df3\
    .join(df4, (df3.message_run_starts == df4.message_run_ends) & (df3.start_row_id == df4.end_row_id), how='full')\
        .drop(*['message_run_ends','start_row_id','end_row_id'])\
        .withColumn('end_time',coalesce(col('end_time'),col('start_time'))) 

df_combined.show()

+----------+------------------+--------+
|start_time|message_run_starts|end_time|
+----------+------------------+--------+
|       t01|                m1|     t03|
|       t07|                m1|     t08|
|       t22|                m1|     t22|
|       t03|                m2|     t04|
|       t11|                m2|     t15|
|       t06|                m3|     t07|
|       t13|                m4|     t13|
|       t20|                m4|     t20|
|       t22|                m4|     t22|
+----------+------------------+--------+

您可以尝试使用前向填充的以下方法(不需要 Spark 2.4+):

第 1 步:执行以下操作:

  1. 对于按 time 排序的每一行,查找 prev_messagesnext_messages
  2. 条消息 分解为单独的 条消息
  3. 对于每条消息,如果prev_messages为NULL或者消息不在[=35中=]prev_messages,然后设置start=time,见下面SQL语法:

    IF(prev_messages is NULL or !array_contains(prev_messages, message),time,NULL)
    

    可以简化为:

    IF(array_contains(prev_messages, message),NULL,time)
    
  4. 并且如果 next_messages 为 NULL 或 message 不在 next_messages,然后设置end=time

代码如下:

from pyspark.sql import Window, functions as F

# rows is defined in your own post
df = spark.createDataFrame(rows, ['time', 'messages'])

w1 = Window.partitionBy().orderBy('time')

df1 = df.withColumn('prev_messages', F.lag('messages').over(w1)) \
    .withColumn('next_messages', F.lead('messages').over(w1)) \
    .withColumn('message', F.explode('messages')) \
    .withColumn('start', F.expr("IF(array_contains(prev_messages, message),NULL,time)")) \
    .withColumn('end', F.expr("IF(array_contains(next_messages, message),NULL,time)"))

df1.show()
#+----+--------+-------------+-------------+-------+-----+----+
#|time|messages|prev_messages|next_messages|message|start| end|
#+----+--------+-------------+-------------+-------+-----+----+
#| t01|    [m1]|         null|     [m1, m2]|     m1|  t01|null|
#| t03|[m1, m2]|         [m1]|         [m2]|     m1| null| t03|
#| t03|[m1, m2]|         [m1]|         [m2]|     m2|  t03|null|
#| t04|    [m2]|     [m1, m2]|         [m3]|     m2| null| t04|
#| t06|    [m3]|         [m2]|     [m3, m1]|     m3|  t06|null|
#| t07|[m3, m1]|         [m3]|         [m1]|     m3| null| t07|
#| t07|[m3, m1]|         [m3]|         [m1]|     m1|  t07|null|
#| t08|    [m1]|     [m3, m1]|         [m2]|     m1| null| t08|
#| t11|    [m2]|         [m1]|     [m2, m4]|     m2|  t11|null|
#| t13|[m2, m4]|         [m2]|         [m2]|     m2| null|null|
#| t13|[m2, m4]|         [m2]|         [m2]|     m4|  t13| t13|
#| t15|    [m2]|     [m2, m4]|         [m4]|     m2| null| t15|
#| t20|    [m4]|         [m2]|           []|     m4|  t20| t20|
#| t22|[m1, m4]|           []|         null|     m1|  t22| t22|
#| t22|[m1, m4]|           []|         null|     m4|  t22| t22|
#+----+--------+-------------+-------------+-------+-----+----+

第 2 步:创建按消息分区的 WindSpec 并向前填充到 start 列。

w2 = Window.partitionBy('message').orderBy('time')

# for illustration purpose, I used a different column-name so that we can 
# compare `start` column before and after ffill
df2 = df1.withColumn('start_new', F.last('start', True).over(w2))
df2.show()
#+----+--------+-------------+-------------+-------+-----+----+---------+
#|time|messages|prev_messages|next_messages|message|start| end|start_new|
#+----+--------+-------------+-------------+-------+-----+----+---------+
#| t01|    [m1]|         null|     [m1, m2]|     m1|  t01|null|      t01|
#| t03|[m1, m2]|         [m1]|         [m2]|     m1| null| t03|      t01|
#| t07|[m3, m1]|         [m3]|         [m1]|     m1|  t07|null|      t07|
#| t08|    [m1]|     [m3, m1]|         [m2]|     m1| null| t08|      t07|
#| t22|[m1, m4]|           []|         null|     m1|  t22| t22|      t22|
#| t03|[m1, m2]|         [m1]|         [m2]|     m2|  t03|null|      t03|
#| t04|    [m2]|     [m1, m2]|         [m3]|     m2| null| t04|      t03|
#| t11|    [m2]|         [m1]|     [m2, m4]|     m2|  t11|null|      t11|
#| t13|[m2, m4]|         [m2]|         [m2]|     m2| null|null|      t11|
#| t15|    [m2]|     [m2, m4]|         [m4]|     m2| null| t15|      t11|
#| t06|    [m3]|         [m2]|     [m3, m1]|     m3|  t06|null|      t06|
#| t07|[m3, m1]|         [m3]|         [m1]|     m3| null| t07|      t06|
#| t13|[m2, m4]|         [m2]|         [m2]|     m4|  t13| t13|      t13|
#| t20|    [m4]|         [m2]|           []|     m4|  t20| t20|      t20|
#| t22|[m1, m4]|           []|         null|     m4|  t22| t22|      t22|
#+----+--------+-------------+-------------+-------+-----+----+---------+

第 3 步:删除末尾为 NULL 的行,然后 select 只需要列:

df2.selectExpr("message", "start_new as start", "end") \
    .filter("end is not NULL") \
    .orderBy("message","start").show()
#+-------+-----+---+
#|message|start|end|
#+-------+-----+---+
#|     m1|  t01|t03|
#|     m1|  t07|t08|
#|     m1|  t22|t22|
#|     m2|  t03|t04|
#|     m2|  t11|t15|
#|     m3|  t06|t07|
#|     m4|  t13|t13|
#|     m4|  t20|t20|
#|     m4|  t22|t22|
#+-------+-----+---+

总结以上步骤,我们有以下内容:

from pyspark.sql import Window, functions as F

# define two Window Specs
w1 = Window.partitionBy().orderBy('time')
w2 = Window.partitionBy('message').orderBy('time')

df_new = df \
    .withColumn('prev_messages', F.lag('messages').over(w1)) \
    .withColumn('next_messages', F.lead('messages').over(w1)) \
    .withColumn('message', F.explode('messages')) \
    .withColumn('start', F.expr("IF(array_contains(prev_messages, message),NULL,time)")) \
    .withColumn('end', F.expr("IF(array_contains(next_messages, message),NULL,time)")) \
    .withColumn('start', F.last('start', True).over(w2)) \
    .select("message", "start", "end") \
    .filter("end is not NULL")

df_new.orderBy("start").show()

在这里您可以找到 array functions in spark 2.4 的信息,而 explode_outer 是空数组中的展开,将生成具有 'null' 值的行。

思路是先获取每个时刻,开始的消息数组,以及每个时刻结束的消息数组(start_of和end_of)。

然后,我们只保留一条消息开始或结束的时刻,然后创建并分解为具有 3 列的数据框,每条消息开始和结束各一列。在创建 m1 和 m2 的那一刻,将产生 2 个起始行,在 m1 开始和结束的那一刻,将产生 2 行,带有 m1 星号和 m1 结束号。

最后,使用 window 函数按 'message' 分组并按时间排序,确保如果一条消息在同一时刻(同一时间)开始和结束,则开始会先走。现在我们可以保证每次开始后,都会有一个结束行。 混合它们,您将拥有每条消息的开头和结尾。

很好的思考练习。

我已经在 scala 中制作了示例,但应该很容易翻译。标记为 showAndContinue 的每一行都会在该状态下打印您的示例以显示它的作用。

val w = Window.partitionBy().orderBy("time")
val w2 = Window.partitionBy("message").orderBy($"time", desc("start_of"))
df.select($"time", $"messages", lag($"messages", 1).over(w).as("pre"), lag("messages", -1).over(w).as("post"))
  .withColumn("start_of", when($"pre".isNotNull, array_except(col("messages"), col("pre"))).otherwise($"messages"))
  .withColumn("end_of",  when($"post".isNotNull, array_except(col("messages"), col("post"))).otherwise($"messages"))
  .filter(size($"start_of") + size($"end_of") > 0)
  .showAndContinue
  .select(explode(array(
    struct($"time", $"start_of", array().as("end_of")),
    struct($"time", array().as("start_of"), $"end_of")
  )).as("elem"))
  .select("elem.*")
  .select($"time", explode_outer($"start_of").as("start_of"), $"end_of")
  .select( $"time", $"start_of", explode_outer($"end_of").as("end_of"))
  .filter($"start_of".isNotNull || $"end_of".isNotNull)
  .showAndContinue
  .withColumn("message", when($"start_of".isNotNull, $"start_of").otherwise($"end_of"))
  .showAndContinue
  .select($"message", when($"start_of".isNotNull, $"time").as("starts_at"), lag($"time", -1).over(w2).as("ends_at"))
  .filter($"starts_at".isNotNull)
  .showAndContinue

和 tables

+----+--------+--------+--------+--------+--------+
|time|messages|     pre|    post|start_of|  end_of|
+----+--------+--------+--------+--------+--------+
| t01|    [m1]|    null|[m1, m2]|    [m1]|      []|
| t03|[m1, m2]|    [m1]|    [m2]|    [m2]|    [m1]|
| t04|    [m2]|[m1, m2]|    [m3]|      []|    [m2]|
| t06|    [m3]|    [m2]|[m3, m1]|    [m3]|      []|
| t07|[m3, m1]|    [m3]|    [m1]|    [m1]|    [m3]|
| t08|    [m1]|[m3, m1]|    [m2]|      []|    [m1]|
| t11|    [m2]|    [m1]|[m2, m4]|    [m2]|      []|
| t13|[m2, m4]|    [m2]|    [m2]|    [m4]|    [m4]|
| t15|    [m2]|[m2, m4]|    [m4]|      []|    [m2]|
| t20|    [m4]|    [m2]|      []|    [m4]|    [m4]|
| t22|[m1, m4]|      []|    null|[m1, m4]|[m1, m4]|
+----+--------+--------+--------+--------+--------+

+----+--------+------+
|time|start_of|end_of|
+----+--------+------+
| t01|      m1|  null|
| t03|      m2|  null|
| t03|    null|    m1|
| t04|    null|    m2|
| t06|      m3|  null|
| t07|      m1|  null|
| t07|    null|    m3|
| t08|    null|    m1|
| t11|      m2|  null|
| t13|      m4|  null|
| t13|    null|    m4|
| t15|    null|    m2|
| t20|      m4|  null|
| t20|    null|    m4|
| t22|      m1|  null|
| t22|      m4|  null|
| t22|    null|    m1|
| t22|    null|    m4|
+----+--------+------+

+----+--------+------+-------+
|time|start_of|end_of|message|
+----+--------+------+-------+
| t01|      m1|  null|     m1|
| t03|      m2|  null|     m2|
| t03|    null|    m1|     m1|
| t04|    null|    m2|     m2|
| t06|      m3|  null|     m3|
| t07|      m1|  null|     m1|
| t07|    null|    m3|     m3|
| t08|    null|    m1|     m1|
| t11|      m2|  null|     m2|
| t13|      m4|  null|     m4|
| t13|    null|    m4|     m4|
| t15|    null|    m2|     m2|
| t20|      m4|  null|     m4|
| t20|    null|    m4|     m4|
| t22|      m1|  null|     m1|
| t22|      m4|  null|     m4|
| t22|    null|    m1|     m1|
| t22|    null|    m4|     m4|
+----+--------+------+-------+

+-------+---------+-------+
|message|starts_at|ends_at|
+-------+---------+-------+
|     m1|      t01|    t03|
|     m1|      t07|    t08|
|     m1|      t22|    t22|
|     m2|      t03|    t04|
|     m2|      t11|    t15|
|     m3|      t06|    t07|
|     m4|      t13|    t13|
|     m4|      t20|    t20|
|     m4|      t22|    t22|
+-------+---------+-------+

可以优化提取在同一时刻开始和结束的所有元素,在第一个 table 创建时,因此它们不必再次成为 "matched" 的开始和结束,但这取决于这是普遍情况,还是只是少数情况。 优化后会这样(同windows)

val dfStartEndAndFiniteLife = df.select($"time", $"messages", lag($"messages", 1).over(w).as("pre"), lag("messages", -1).over(w).as("post"))
  .withColumn("start_of", when($"pre".isNotNull, array_except(col("messages"), col("pre"))).otherwise($"messages"))
  .withColumn("end_of",  when($"post".isNotNull, array_except(col("messages"), col("post"))).otherwise($"messages"))
  .filter(size($"start_of") + size($"end_of") > 0)
  .withColumn("start_end_here", array_intersect($"start_of", $"end_of"))
  .withColumn("start_of", array_except($"start_of", $"start_end_here"))
  .withColumn("end_of", array_except($"end_of", $"start_end_here"))
  .showAndContinue

val onlyStartEndSameMoment = dfStartEndAndFiniteLife.filter(size($"start_end_here") > 0)
  .select(explode($"start_end_here"), $"time".as("starts_at"), $"time".as("ends_at"))
  .showAndContinue

val startEndDifferentMoment = dfStartEndAndFiniteLife
  .filter(size($"start_of") + size($"end_of") > 0)
  .showAndContinue
  .select(explode(array(
    struct($"time", $"start_of", array().as("end_of")),
    struct($"time", array().as("start_of"), $"end_of")
  )).as("elem"))
  .select("elem.*")
  .select($"time", explode_outer($"start_of").as("start_of"), $"end_of")
  .select( $"time", $"start_of", explode_outer($"end_of").as("end_of"))
  .filter($"start_of".isNotNull || $"end_of".isNotNull)
  .showAndContinue
  .withColumn("message", when($"start_of".isNotNull, $"start_of").otherwise($"end_of"))
  .showAndContinue
  .select($"message", when($"start_of".isNotNull, $"time").as("starts_at"), lag($"time", -1).over(w2).as("ends_at"))
  .filter($"starts_at".isNotNull)
  .showAndContinue

val result = onlyStartEndSameMoment.union(startEndDifferentMoment)

result.orderBy("col", "starts_at").show()

和 tables

+----+--------+--------+--------+--------+------+--------------+
|time|messages|     pre|    post|start_of|end_of|start_end_here|
+----+--------+--------+--------+--------+------+--------------+
| t01|    [m1]|    null|[m1, m2]|    [m1]|    []|            []|
| t03|[m1, m2]|    [m1]|    [m2]|    [m2]|  [m1]|            []|
| t04|    [m2]|[m1, m2]|    [m3]|      []|  [m2]|            []|
| t06|    [m3]|    [m2]|[m3, m1]|    [m3]|    []|            []|
| t07|[m3, m1]|    [m3]|    [m1]|    [m1]|  [m3]|            []|
| t08|    [m1]|[m3, m1]|    [m2]|      []|  [m1]|            []|
| t11|    [m2]|    [m1]|[m2, m4]|    [m2]|    []|            []|
| t13|[m2, m4]|    [m2]|    [m2]|      []|    []|          [m4]|
| t15|    [m2]|[m2, m4]|    [m4]|      []|  [m2]|            []|
| t20|    [m4]|    [m2]|      []|      []|    []|          [m4]|
| t22|[m1, m4]|      []|    null|      []|    []|      [m1, m4]|
+----+--------+--------+--------+--------+------+--------------+

+---+---------+-------+
|col|starts_at|ends_at|
+---+---------+-------+
| m4|      t13|    t13|
| m4|      t20|    t20|
| m1|      t22|    t22|
| m4|      t22|    t22|
+---+---------+-------+

+----+--------+--------+--------+--------+------+--------------+
|time|messages|     pre|    post|start_of|end_of|start_end_here|
+----+--------+--------+--------+--------+------+--------------+
| t01|    [m1]|    null|[m1, m2]|    [m1]|    []|            []|
| t03|[m1, m2]|    [m1]|    [m2]|    [m2]|  [m1]|            []|
| t04|    [m2]|[m1, m2]|    [m3]|      []|  [m2]|            []|
| t06|    [m3]|    [m2]|[m3, m1]|    [m3]|    []|            []|
| t07|[m3, m1]|    [m3]|    [m1]|    [m1]|  [m3]|            []|
| t08|    [m1]|[m3, m1]|    [m2]|      []|  [m1]|            []|
| t11|    [m2]|    [m1]|[m2, m4]|    [m2]|    []|            []|
| t15|    [m2]|[m2, m4]|    [m4]|      []|  [m2]|            []|
+----+--------+--------+--------+--------+------+--------------+

+----+--------+------+
|time|start_of|end_of|
+----+--------+------+
| t01|      m1|  null|
| t03|      m2|  null|
| t03|    null|    m1|
| t04|    null|    m2|
| t06|      m3|  null|
| t07|      m1|  null|
| t07|    null|    m3|
| t08|    null|    m1|
| t11|      m2|  null|
| t15|    null|    m2|
+----+--------+------+

+----+--------+------+-------+
|time|start_of|end_of|message|
+----+--------+------+-------+
| t01|      m1|  null|     m1|
| t03|      m2|  null|     m2|
| t03|    null|    m1|     m1|
| t04|    null|    m2|     m2|
| t06|      m3|  null|     m3|
| t07|      m1|  null|     m1|
| t07|    null|    m3|     m3|
| t08|    null|    m1|     m1|
| t11|      m2|  null|     m2|
| t15|    null|    m2|     m2|
+----+--------+------+-------+

+-------+---------+-------+
|message|starts_at|ends_at|
+-------+---------+-------+
|     m1|      t01|    t03|
|     m1|      t07|    t08|
|     m2|      t03|    t04|
|     m2|      t11|    t15|
|     m3|      t06|    t07|
+-------+---------+-------+

+---+---------+-------+
|col|starts_at|ends_at|
+---+---------+-------+
| m1|      t01|    t03|
| m1|      t07|    t08|
| m1|      t22|    t22|
| m2|      t03|    t04|
| m2|      t11|    t15|
| m3|      t06|    t07|
| m4|      t13|    t13|
| m4|      t20|    t20|
| m4|      t22|    t22|
+---+---------+-------+