PySpark 数据帧条件 window/lag
PySpark dataframe condition by window/lag
我有一个 Spark 数据框,像这样:
# For sake of simplicity only one user (uid) is shown, but there are multiple users
+-------------------+-----+-------+
|start_date |uid |count |
+-------------------+-----+-------+
|2020-11-26 08:30:22|user1| 4 |
|2020-11-26 10:00:00|user1| 3 |
|2020-11-22 08:37:18|user1| 3 |
|2020-11-22 13:32:30|user1| 2 |
|2020-11-20 16:04:04|user1| 2 |
|2020-11-16 12:04:04|user1| 1 |
如果用户过去至少有 count >= x 个事件,我想创建一个值为 True/False 的新布尔列,和用True标记这些事件。例如,对于 x=3 我希望得到:
+-------------------+-----+-------+--------------+
|start_date |uid |count | marked_event |
+-------------------+-----+-------+--------------+
|2020-11-26 08:30:22|user1| 4 | True |
|2020-11-26 10:00:00|user1| 3 | True |
|2020-11-22 08:37:18|user1| 3 | True |
|2020-11-22 13:32:30|user1| 2 | True |
|2020-11-20 16:04:04|user1| 2 | True |
|2020-11-16 12:04:04|user1| 1 | False |
也就是说,对于每个 >= 3 的计数,我需要用 True 标记该事件,以及之前的 3 个事件。只有 user1 的最后一个事件是 False,因为我在 start_date = 2020-11-22 08:37:18.
上的事件之前(包括)标记了 3 个事件
有什么想法可以解决这个问题吗?我的直觉是以某种方式使用 window/lag 来实现这一点,但我是新手,不知道该怎么做...
编辑:
我结束了使用@mck 解决方案的变体,修复了一个小错误:原始解决方案有:
F.max(F.col('begin')).over(w.rowsBetween(0, Window.unboundedFollowing))
条件,最终标记所有事件在'begin'之后,无论'count'的条件是否满足。相反,我更改了解决方案,以便 window 仅标记 'begin':
之前发生的事件
event = (f.max(f.col('begin')).over(w.rowsBetween(-2, 0))).\
alias('event_post_only')
# the number of events to mark is 3 from 'begin',
# including the event itself, so that's -2.
df_marked_events = df_marked_events.select('*', event)
然后为在 'event_post_only' 中为真或在 'event_post_only'
中为真的所有事件标记为真
df_marked_events = df_marked_events.withColumn('event', (col('count') >= 3) \
| (col('event_post_only')))
这避免将 上游的所有内容 标记为 True 到 'begin' == True
import pyspark.sql.functions as F
from pyspark.sql.window import Window
w = Window.partitionBy('uid').orderBy(F.col('count').desc(), F.col('start_date'))
# find the beginning point of >= 3 events
begin = (
(F.col('count') >= 3) &
(F.lead(F.col('count')).over(w) < 3)
).alias('begin')
df = df.select('*', begin)
# Mark as event if the event is in any rows after begin, or two rows before begin
event = (
F.max(F.col('begin')).over(w.rowsBetween(0, Window.unboundedFollowing)) |
F.max(F.col('begin')).over(w.rowsBetween(-2,0))
).alias('event')
df = df.select('*', event)
df.show()
+-------------------+-----+-----+-----+-----+
| start_date| uid|count|begin|event|
+-------------------+-----+-----+-----+-----+
|2020-11-26 08:30:22|user1| 4.0|false| true|
|2020-11-22 08:37:18|user1| 3.0|false| true|
|2020-11-26 10:00:00|user1| 3.0| true| true|
|2020-11-20 16:04:04|user1| 2.0|false| true|
|2020-11-22 13:32:30|user1| 2.0|false| true|
|2020-11-16 12:04:04|user1| 1.0|false|false|
+-------------------+-----+-----+-----+-----+
我有一个 Spark 数据框,像这样:
# For sake of simplicity only one user (uid) is shown, but there are multiple users
+-------------------+-----+-------+
|start_date |uid |count |
+-------------------+-----+-------+
|2020-11-26 08:30:22|user1| 4 |
|2020-11-26 10:00:00|user1| 3 |
|2020-11-22 08:37:18|user1| 3 |
|2020-11-22 13:32:30|user1| 2 |
|2020-11-20 16:04:04|user1| 2 |
|2020-11-16 12:04:04|user1| 1 |
如果用户过去至少有 count >= x 个事件,我想创建一个值为 True/False 的新布尔列,和用True标记这些事件。例如,对于 x=3 我希望得到:
+-------------------+-----+-------+--------------+
|start_date |uid |count | marked_event |
+-------------------+-----+-------+--------------+
|2020-11-26 08:30:22|user1| 4 | True |
|2020-11-26 10:00:00|user1| 3 | True |
|2020-11-22 08:37:18|user1| 3 | True |
|2020-11-22 13:32:30|user1| 2 | True |
|2020-11-20 16:04:04|user1| 2 | True |
|2020-11-16 12:04:04|user1| 1 | False |
也就是说,对于每个 >= 3 的计数,我需要用 True 标记该事件,以及之前的 3 个事件。只有 user1 的最后一个事件是 False,因为我在 start_date = 2020-11-22 08:37:18.
上的事件之前(包括)标记了 3 个事件有什么想法可以解决这个问题吗?我的直觉是以某种方式使用 window/lag 来实现这一点,但我是新手,不知道该怎么做...
编辑:
我结束了使用@mck 解决方案的变体,修复了一个小错误:原始解决方案有:
F.max(F.col('begin')).over(w.rowsBetween(0, Window.unboundedFollowing))
条件,最终标记所有事件在'begin'之后,无论'count'的条件是否满足。相反,我更改了解决方案,以便 window 仅标记 'begin':
之前发生的事件event = (f.max(f.col('begin')).over(w.rowsBetween(-2, 0))).\
alias('event_post_only')
# the number of events to mark is 3 from 'begin',
# including the event itself, so that's -2.
df_marked_events = df_marked_events.select('*', event)
然后为在 'event_post_only' 中为真或在 'event_post_only'
中为真的所有事件标记为真df_marked_events = df_marked_events.withColumn('event', (col('count') >= 3) \
| (col('event_post_only')))
这避免将 上游的所有内容 标记为 True 到 'begin' == True
import pyspark.sql.functions as F
from pyspark.sql.window import Window
w = Window.partitionBy('uid').orderBy(F.col('count').desc(), F.col('start_date'))
# find the beginning point of >= 3 events
begin = (
(F.col('count') >= 3) &
(F.lead(F.col('count')).over(w) < 3)
).alias('begin')
df = df.select('*', begin)
# Mark as event if the event is in any rows after begin, or two rows before begin
event = (
F.max(F.col('begin')).over(w.rowsBetween(0, Window.unboundedFollowing)) |
F.max(F.col('begin')).over(w.rowsBetween(-2,0))
).alias('event')
df = df.select('*', event)
df.show()
+-------------------+-----+-----+-----+-----+
| start_date| uid|count|begin|event|
+-------------------+-----+-----+-----+-----+
|2020-11-26 08:30:22|user1| 4.0|false| true|
|2020-11-22 08:37:18|user1| 3.0|false| true|
|2020-11-26 10:00:00|user1| 3.0| true| true|
|2020-11-20 16:04:04|user1| 2.0|false| true|
|2020-11-22 13:32:30|user1| 2.0|false| true|
|2020-11-16 12:04:04|user1| 1.0|false|false|
+-------------------+-----+-----+-----+-----+