pyspark select 在某些情况下第一个元素超过 window

pyspark select first element over window on some condition

问题

您好,在某些情况下,pyspark/spark 中是否有方法 select 第一个元素超过某些 window?

例子

让我们来看一个示例输入数据框

+---------+----------+----+----+----------------+
|       id| timestamp|  f1|  f2|        computed|
+---------+----------+----+----+----------------+
|        1|2020-01-02|null|c1f2|            [f2]|
|        1|2020-01-01|c1f1|null|            [f1]|
|        2|2020-01-01|c2f1|null|            [f1]|
+---------+----------+----+----+----------------+

我想 select 为计算的每个 ID 最新列(f1、f2...)。

所以“代码”看起来像这样

cols = ["f1", "f2"]

w = Window().partitionBy("id").orderBy(f.desc("timestamp")).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

output_df = (
    input_df.select(
        "id",
        *[f.first(col, condition=f.array_contains(f.col("computed"), col)).over(w).alias(col) for col in cols]
    )
    .groupBy("id")
    .agg(*[f.first(col).alias(col) for col in cols])
)

输出应该是

+---------+----+----+
|       id|  f1|  f2|
+---------+----+----+
|        1|c1f1|c1f2|
|        2|c2f1|null|
+---------+----+----+

如果输入是这样的

+---------+----------+----+----+----------------+
|       id| timestamp|  f1|  f2|        computed|
+---------+----------+----+----+----------------+
|        1|2020-01-02|null|c1f2|        [f1, f2]|
|        1|2020-01-01|c1f1|null|            [f1]|
|        2|2020-01-01|c2f1|null|            [f1]|
+---------+----------+----+----+----------------+

那么输出应该是

+---------+----+----+
|       id|  f1|  f2|
+---------+----+----+
|        1|null|c1f2|
|        2|c2f1|null|
+---------+----+----+

如您所见,仅使用 f.first(ignore_nulls=True) 并不容易,因为在这种情况下我们不想跳过 null,因为它被视为计算值。

当前解决方案

步骤 1

保存原始数据类型

cols = ["f1", "f2"]
orig_dtypes = [field.dataType for field in input_df.schema if field.name in cols]

第 2 步

对于每一列,如果计算了该列,则用它的值创建新列,并将原始空值替换为我们的“合成”<NULL> 字符串

output_df = input_df.select(
    "id", "timestamp", "computed",
    *[
        f.when(f.array_contains(f.col("computed"), col) & f.col(col).isNotNull(), f.col(col))
        .when(f.array_contains(f.col("computed"), col) & f.col(col).isNull(), "<NULL>")
        .alias(col)
        for col in cols
    ]
)

步骤 3

Select 第一个非空值超过 window 因为现在我们知道 <NULL> 不会被跳过

output_df = (
    output_df.select(
        "id",
        *[f.first(col, ignorenulls=True).over(w).alias(col) for col in cols],
    )
    .groupBy("id")
    .agg(*[f.first(col).alias(col) for col in cols])
)

步骤 4

将我们的“合成”<NULL> 替换为原始空值。

output_df = output_df.replace("<NULL>", None)

第 5 步

将列转换回其原始类型,因为它们可能会在步骤 2 中重新键入为字符串

output_df = output_df.select("id", *[f.col(col).cast(type_) for col, type_ in zip(cols, orig_dtypes)])

此解决方案有效,但似乎不是正确的方法。此外,它非常重,而且计算时间太长。

还有其他更“闪亮”的方法吗?

这是使用结构排序技巧的一种方法。

Groupby id 并为 cols 列表中的每一列收集结构列表,例如 struct<col_exists_in_computed, timestamp, col_value>,然后在结果数组上使用 array_max 函数,您可以获得最后的值你想要:

from pyspark.sql import functions as F

output_df = input_df.groupBy("id").agg(
    *[F.array_max(
        F.collect_list(
          F.struct(F.array_contains("computed", c), F.col("timestamp"), F.col(c))
        )
    )[c].alias(c) for c in cols]
)

# applied to you second dataframe example, it gives

output_df.show()
#+---+----+----+
#| id|  f1|  f2|
#+---+----+----+
#|  1|null|c1f2|
#|  2|c2f1|null|
#+---+----+----+