pyspark select 在某些情况下第一个元素超过 window
pyspark select first element over window on some condition
问题
您好,在某些情况下,pyspark/spark 中是否有方法 select 第一个元素超过某些 window?
例子
让我们来看一个示例输入数据框
+---------+----------+----+----+----------------+
| id| timestamp| f1| f2| computed|
+---------+----------+----+----+----------------+
| 1|2020-01-02|null|c1f2| [f2]|
| 1|2020-01-01|c1f1|null| [f1]|
| 2|2020-01-01|c2f1|null| [f1]|
+---------+----------+----+----+----------------+
我想 select 为计算的每个 ID 最新列(f1、f2...)。
所以“代码”看起来像这样
cols = ["f1", "f2"]
w = Window().partitionBy("id").orderBy(f.desc("timestamp")).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
output_df = (
input_df.select(
"id",
*[f.first(col, condition=f.array_contains(f.col("computed"), col)).over(w).alias(col) for col in cols]
)
.groupBy("id")
.agg(*[f.first(col).alias(col) for col in cols])
)
输出应该是
+---------+----+----+
| id| f1| f2|
+---------+----+----+
| 1|c1f1|c1f2|
| 2|c2f1|null|
+---------+----+----+
如果输入是这样的
+---------+----------+----+----+----------------+
| id| timestamp| f1| f2| computed|
+---------+----------+----+----+----------------+
| 1|2020-01-02|null|c1f2| [f1, f2]|
| 1|2020-01-01|c1f1|null| [f1]|
| 2|2020-01-01|c2f1|null| [f1]|
+---------+----------+----+----+----------------+
那么输出应该是
+---------+----+----+
| id| f1| f2|
+---------+----+----+
| 1|null|c1f2|
| 2|c2f1|null|
+---------+----+----+
如您所见,仅使用 f.first(ignore_nulls=True)
并不容易,因为在这种情况下我们不想跳过 null,因为它被视为计算值。
当前解决方案
步骤 1
保存原始数据类型
cols = ["f1", "f2"]
orig_dtypes = [field.dataType for field in input_df.schema if field.name in cols]
第 2 步
对于每一列,如果计算了该列,则用它的值创建新列,并将原始空值替换为我们的“合成”<NULL>
字符串
output_df = input_df.select(
"id", "timestamp", "computed",
*[
f.when(f.array_contains(f.col("computed"), col) & f.col(col).isNotNull(), f.col(col))
.when(f.array_contains(f.col("computed"), col) & f.col(col).isNull(), "<NULL>")
.alias(col)
for col in cols
]
)
步骤 3
Select 第一个非空值超过 window 因为现在我们知道 <NULL>
不会被跳过
output_df = (
output_df.select(
"id",
*[f.first(col, ignorenulls=True).over(w).alias(col) for col in cols],
)
.groupBy("id")
.agg(*[f.first(col).alias(col) for col in cols])
)
步骤 4
将我们的“合成”<NULL>
替换为原始空值。
output_df = output_df.replace("<NULL>", None)
第 5 步
将列转换回其原始类型,因为它们可能会在步骤 2 中重新键入为字符串
output_df = output_df.select("id", *[f.col(col).cast(type_) for col, type_ in zip(cols, orig_dtypes)])
此解决方案有效,但似乎不是正确的方法。此外,它非常重,而且计算时间太长。
还有其他更“闪亮”的方法吗?
这是使用结构排序技巧的一种方法。
Groupby id
并为 cols
列表中的每一列收集结构列表,例如 struct<col_exists_in_computed, timestamp, col_value>
,然后在结果数组上使用 array_max
函数,您可以获得最后的值你想要:
from pyspark.sql import functions as F
output_df = input_df.groupBy("id").agg(
*[F.array_max(
F.collect_list(
F.struct(F.array_contains("computed", c), F.col("timestamp"), F.col(c))
)
)[c].alias(c) for c in cols]
)
# applied to you second dataframe example, it gives
output_df.show()
#+---+----+----+
#| id| f1| f2|
#+---+----+----+
#| 1|null|c1f2|
#| 2|c2f1|null|
#+---+----+----+
问题
您好,在某些情况下,pyspark/spark 中是否有方法 select 第一个元素超过某些 window?
例子
让我们来看一个示例输入数据框
+---------+----------+----+----+----------------+
| id| timestamp| f1| f2| computed|
+---------+----------+----+----+----------------+
| 1|2020-01-02|null|c1f2| [f2]|
| 1|2020-01-01|c1f1|null| [f1]|
| 2|2020-01-01|c2f1|null| [f1]|
+---------+----------+----+----+----------------+
我想 select 为计算的每个 ID 最新列(f1、f2...)。
所以“代码”看起来像这样
cols = ["f1", "f2"]
w = Window().partitionBy("id").orderBy(f.desc("timestamp")).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
output_df = (
input_df.select(
"id",
*[f.first(col, condition=f.array_contains(f.col("computed"), col)).over(w).alias(col) for col in cols]
)
.groupBy("id")
.agg(*[f.first(col).alias(col) for col in cols])
)
输出应该是
+---------+----+----+
| id| f1| f2|
+---------+----+----+
| 1|c1f1|c1f2|
| 2|c2f1|null|
+---------+----+----+
如果输入是这样的
+---------+----------+----+----+----------------+
| id| timestamp| f1| f2| computed|
+---------+----------+----+----+----------------+
| 1|2020-01-02|null|c1f2| [f1, f2]|
| 1|2020-01-01|c1f1|null| [f1]|
| 2|2020-01-01|c2f1|null| [f1]|
+---------+----------+----+----+----------------+
那么输出应该是
+---------+----+----+
| id| f1| f2|
+---------+----+----+
| 1|null|c1f2|
| 2|c2f1|null|
+---------+----+----+
如您所见,仅使用 f.first(ignore_nulls=True)
并不容易,因为在这种情况下我们不想跳过 null,因为它被视为计算值。
当前解决方案
步骤 1
保存原始数据类型
cols = ["f1", "f2"]
orig_dtypes = [field.dataType for field in input_df.schema if field.name in cols]
第 2 步
对于每一列,如果计算了该列,则用它的值创建新列,并将原始空值替换为我们的“合成”<NULL>
字符串
output_df = input_df.select(
"id", "timestamp", "computed",
*[
f.when(f.array_contains(f.col("computed"), col) & f.col(col).isNotNull(), f.col(col))
.when(f.array_contains(f.col("computed"), col) & f.col(col).isNull(), "<NULL>")
.alias(col)
for col in cols
]
)
步骤 3
Select 第一个非空值超过 window 因为现在我们知道 <NULL>
不会被跳过
output_df = (
output_df.select(
"id",
*[f.first(col, ignorenulls=True).over(w).alias(col) for col in cols],
)
.groupBy("id")
.agg(*[f.first(col).alias(col) for col in cols])
)
步骤 4
将我们的“合成”<NULL>
替换为原始空值。
output_df = output_df.replace("<NULL>", None)
第 5 步
将列转换回其原始类型,因为它们可能会在步骤 2 中重新键入为字符串
output_df = output_df.select("id", *[f.col(col).cast(type_) for col, type_ in zip(cols, orig_dtypes)])
此解决方案有效,但似乎不是正确的方法。此外,它非常重,而且计算时间太长。
还有其他更“闪亮”的方法吗?
这是使用结构排序技巧的一种方法。
Groupby id
并为 cols
列表中的每一列收集结构列表,例如 struct<col_exists_in_computed, timestamp, col_value>
,然后在结果数组上使用 array_max
函数,您可以获得最后的值你想要:
from pyspark.sql import functions as F
output_df = input_df.groupBy("id").agg(
*[F.array_max(
F.collect_list(
F.struct(F.array_contains("computed", c), F.col("timestamp"), F.col(c))
)
)[c].alias(c) for c in cols]
)
# applied to you second dataframe example, it gives
output_df.show()
#+---+----+----+
#| id| f1| f2|
#+---+----+----+
#| 1|null|c1f2|
#| 2|c2f1|null|
#+---+----+----+