Pyspark: Check if column values are monotonically increasing
Question: Given the PySpark DataFrame below, is it possible to check row by row whether "some_value" has increased compared to the previous row, using a window function (see the example below)?
A solution without lag would be preferred, because I will have several columns like "some_value" and I don't know in advance how many there are or what they are called.
Example: Here I would like to obtain a column like "FLAG_INCREASE".
+---+----------+---+----------+-------------+
| id|     datum|lfd|some_value|FLAG_INCREASE|
+---+----------+---+----------+-------------+
|  1|2015-01-01|  4|      20.0|            0|
|  1|2015-01-06|  3|      10.0|            0|
|  1|2015-01-07|  2|      25.0|            1|
|  1|2015-01-12|  1|      30.0|            1|
|  2|2015-01-01|  4|       5.0|            0|
|  2|2015-01-06|  3|      30.0|            1|
|  2|2015-01-12|  1|      20.0|            0|
+---+----------+---+----------+-------------+
Code:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql import Row
row = Row("id", "datum", "lfd", "some_value", "some_value2")
df = spark.sparkContext.parallelize([
    row(1, "2015-01-01", 4, 20.0, 20.0),
    row(1, "2015-01-06", 3, 10.0, 20.0),
    row(1, "2015-01-07", 2, 25.0, 20.0),
    row(1, "2015-01-12", 1, 30.0, 20.0),
    row(2, "2015-01-01", 4, 5.0, 20.0),
    row(2, "2015-01-06", 3, 30.0, 20.0),
    row(2, "2015-01-12", 1, 20.0, 20.0)
]).toDF().withColumn("datum", F.col("datum").cast("date"))
+---+----------+---+----------+
| id|     datum|lfd|some_value|
+---+----------+---+----------+
|  1|2015-01-01|  4|      20.0|
|  1|2015-01-06|  3|      10.0|
|  1|2015-01-07|  2|      25.0|
|  1|2015-01-12|  1|      30.0|
|  2|2015-01-01|  4|       5.0|
|  2|2015-01-06|  3|      30.0|
|  2|2015-01-12|  1|      20.0|
+---+----------+---+----------+
You just need lag:
from pyspark.sql import functions as F, Window
df = df.withColumn(
    "FLAG_INCREASE",
    F.when(
        F.col("some_value")
        > F.lag("some_value").over(Window.partitionBy("id").orderBy("datum")),
        1,
    ).otherwise(0),
)
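Since the question says the number and names of the "some_value"-style columns are not known in advance, the same lag-based flag can simply be applied in a loop over whatever value columns the DataFrame has. Below is a minimal sketch, assuming the value columns can be identified as all DoubleType columns; that heuristic and the FLAG_INCREASE_<col> naming are assumptions, not part of the original answer:
from pyspark.sql import functions as F, Window
from pyspark.sql.types import DoubleType

w = Window.partitionBy("id").orderBy("datum")

# Hypothetical heuristic: treat every double column as a "some_value"-like column.
value_cols = [f.name for f in df.schema.fields if isinstance(f.dataType, DoubleType)]

# Add one flag column per value column, comparing each row with the previous one per id.
for c in value_cols:
    df = df.withColumn(
        f"FLAG_INCREASE_{c}",
        F.when(F.col(c) > F.lag(c).over(w), 1).otherwise(0),
    )
For the sample DataFrame above this would produce FLAG_INCREASE_some_value and FLAG_INCREASE_some_value2; the first row of each id partition is flagged 0 because lag returns NULL there, so the comparison is never true.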