Creating a custom counter in Spark based on dataframe conditions
Current dataset:
+---+-----+-----+-----+----+
| ID|Event|Index|start| end|
+---+-----+-----+-----+----+
| 1| run| 0|start|null|
| 1| run| 1| null|null|
| 1| run| 2| null|null|
| 1| swim| 3| null| end|
| 1| run| 4|start|null|
| 1| swim| 5| null|null|
| 1| swim| 6| null| end|
| 1| run| 7|start|null|
| 1| run| 8| null|null|
| 1| run| 9| null|null|
| 1| swim| 10| null| end|
| 1| run| 11|start|null|
| 1| run| 12| null|null|
| 1| run| 13| null| end|
| 2| run| 14|start|null|
| 2| run| 15| null|null|
| 2| run| 16| null|null|
| 2| swim| 17| null| end|
| 2| run| 18|start|null|
| 2| swim| 19| null|null|
| 2| swim| 20| null|null|
| 2| swim| 21| null|null|
| 2| swim| 22| null| end|
| 2| run| 23|start|null|
| 2| run| 24| null|null|
| 2| run| 25| null| end|
| 3| run| 26|start|null|
| 3| run| 27| null|null|
| 3| swim| 28| null|null|
+---+-----+-----+-----+----+
Desired dataset:
+---+-----+-----+-----+----+-------+
| ID|Event|Index|start| end|EventID|
+---+-----+-----+-----+----+-------+
| 1| run| 0|start|null| 1|
| 1| run| 1| null|null| 1|
| 1| run| 2| null|null| 1|
| 1| swim| 3| null| end| 1|
| 1| run| 4|start|null| 2|
| 1| swim| 5| null|null| 2|
| 1| swim| 6| null| end| 2|
| 1| run| 7|start|null| 3|
| 1| run| 8| null|null| 3|
| 1| run| 9| null|null| 3|
| 1| swim| 10| null| end| 3|
| 1| run| 11|start|null| 4|
| 1| run| 12| null|null| 4|
| 1| run| 13| null| end| 4|
| 2| run| 14|start|null| 1|
| 2| run| 15| null|null| 1|
| 2| run| 16| null|null| 1|
| 2| swim| 17| null| end| 1|
| 2| run| 18|start|null| 2|
| 2| swim| 19| null|null| 2|
| 2| swim| 20| null|null| 2|
| 2| swim| 21| null|null| 2|
| 2| swim| 22| null| end| 2|
| 2| run| 23|start|null| 3|
| 2| run| 24| null|null| 3|
| 2| run| 25| null| end| 3|
| 3| run| 26|start|null| 1|
| 3| run| 27| null|null| 1|
| 3| swim| 28| null|null| 1|
+---+-----+-----+-----+----+-------+
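I'm trying to create the EventID column shown above. Is there a way to create a counter inside a UDF that updates based on column conditions? Note that I'm not sure a UDF is the best approach here.
Here is my current logic:
- When a "start" value is seen, start counting.
- When an "end" value is seen, stop counting.
- Each time a new ID is seen, reset the counter to 1.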
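The logic in the list above can be written out in plain Python as a reference for checking expected output on small inputs. This is only a sketch without Spark: `assign_event_ids` and `sample` are hypothetical names, and it resets the counter to 0 on a new ID and increments it on each "start", so that the first event of every ID is numbered 1.

```python
# Plain-Python reference for the counter logic (no Spark involved).
# Rows are (ID, Event, Index, start, end) tuples, as in the data list.
def assign_event_ids(rows):
    """Return the rows ordered by (ID, Index) with an EventID appended."""
    result = []
    current_id = None
    counter = 0
    for row in sorted(rows, key=lambda r: (r[0], r[2])):
        row_id, _event, _index, start, _end = row
        if row_id != current_id:
            # New ID: restart the counter
            current_id = row_id
            counter = 0
        if start == 'start':
            # A 'start' marker opens the next event
            counter += 1
        result.append(row + (counter,))
    return result


sample = [
    (1, "run", 0, 'start', None),
    (1, "swim", 1, None, 'end'),
    (1, "run", 2, 'start', None),
    (2, "run", 3, 'start', None),
    (2, "swim", 4, None, None),
]
event_ids = [r[-1] for r in assign_event_ids(sample)]
print(event_ids)
```

A loop like this depends on visiting rows in order while carrying state from one row to the next. A row-at-a-time UDF generally cannot rely on that, since Spark gives no guarantee about the order in which rows reach it, which is part of why I doubt a UDF is the right tool.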
Thanks for your help.
Here is the code that generates the current dataframe:
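from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# Assumption: no session exists yet; in a notebook or shell, `spark` is usually predefined
spark = SparkSession.builder.getOrCreate()

# Current Dataset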
data = [
(1, "run", 0, 'start', None),
(1, "run", 1, None, None),
(1, "run", 2, None, None),
(1, "swim", 3, None, 'end'),
(1, "run", 4, 'start',None),
(1, "swim", 5, None, None),
(1, "swim", 6, None, 'end'),
(1, "run",7, 'start', None),
(1, "run",8, None, None),
(1, "run",9, None, None),
(1, "swim",10, None, 'end'),
(1, "run",11, 'start', None),
(1, "run",12, None, None),
(1, "run",13, None, 'end'),
(2, "run",14, 'start', None),
(2, "run",15, None, None),
(2, "run",16, None, None),
(2, "swim",17, None, 'end'),
(2, "run",18, 'start', None),
(2, "swim",19, None, None),
(2, "swim",20, None, None),
(2, "swim",21, None, None),
(2, "swim",22, None, 'end'),
(2, "run",23, 'start', None),
(2, "run",24, None, None),
(2, "run",25, None, 'end'),
(3, "run",26, 'start', None),
(3, "run",27, None, None),
(3, "swim",28, None, None)
]
schema = StructType([
    StructField('ID', IntegerType(), True),
    StructField('Event', StringType(), True),
    StructField('Index', IntegerType(), True),
    StructField('start', StringType(), True),
    StructField('end', StringType(), True)
])
df = spark.createDataFrame(data=data, schema=schema)
df.show(30)
You can use a window function:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
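# Per ID, ordered by Index: frame runs from the first row of the partition to the current row
w = Window.partitionBy('ID').orderBy('Index').rowsBetween(Window.unboundedPreceding, Window.currentRow)
# Flag each 'start' row with 1; the running sum of the flags is the event counter
df.withColumn('EventId', F.sum(F.when(F.col('start') == 'start', 1).otherwise(0)).over(w)) \
    .orderBy('ID', 'Index').show(100)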
Result:
+---+-----+-----+-----+----+-------+
| ID|Event|Index|start| end|EventId|
+---+-----+-----+-----+----+-------+
| 1| run| 0|start|null| 1|
| 1| run| 1| null|null| 1|
| 1| run| 2| null|null| 1|
| 1| swim| 3| null| end| 1|
| 1| run| 4|start|null| 2|
| 1| swim| 5| null|null| 2|
| 1| swim| 6| null| end| 2|
| 1| run| 7|start|null| 3|
| 1| run| 8| null|null| 3|
| 1| run| 9| null|null| 3|
| 1| swim| 10| null| end| 3|
| 1| run| 11|start|null| 4|
| 1| run| 12| null|null| 4|
| 1| run| 13| null| end| 4|
| 2| run| 14|start|null| 1|
| 2| run| 15| null|null| 1|
| 2| run| 16| null|null| 1|
| 2| swim| 17| null| end| 1|
| 2| run| 18|start|null| 2|
| 2| swim| 19| null|null| 2|
| 2| swim| 20| null|null| 2|
| 2| swim| 21| null|null| 2|
| 2| swim| 22| null| end| 2|
| 2| run| 23|start|null| 3|
| 2| run| 24| null|null| 3|
| 2| run| 25| null| end| 3|
| 3| run| 26|start|null| 1|
| 3| run| 27| null|null| 1|
| 3| swim| 28| null|null| 1|
+---+-----+-----+-----+----+-------+
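To see why the running sum produces the counter, here is a minimal plain-Python sketch of the same idea, with no Spark involved. `running_start_count` and `sample_rows` are hypothetical names used only for illustration: `groupby` plays the role of `partitionBy('ID')`, the sort plays the role of `orderBy('Index')`, and `accumulate` plays the role of the sum over the frame from the first row to the current row.

```python
from itertools import accumulate, groupby


def running_start_count(rows):
    """Append a per-ID running count of 'start' rows to each (ID, Event, Index, start, end) tuple."""
    # Sort by partition key, then by ordering key
    ordered = sorted(rows, key=lambda r: (r[0], r[2]))
    out = []
    for _id, group in groupby(ordered, key=lambda r: r[0]):
        group = list(group)
        # 1 for a 'start' row, 0 otherwise (mirrors F.when(...).otherwise(0))
        flags = [1 if r[3] == 'start' else 0 for r in group]
        # Cumulative sum of the flags within this ID
        for row, total in zip(group, accumulate(flags)):
            out.append(row + (total,))
    return out


sample_rows = [
    (1, "run", 0, None, None),     # row before any 'start'
    (1, "run", 1, 'start', None),
    (1, "swim", 2, None, 'end'),
    (1, "run", 3, 'start', None),
    (2, "run", 4, 'start', None),
]
counts = [r[-1] for r in running_start_count(sample_rows)]
print(counts)
```

Note that a row that comes before the first "start" of its ID gets 0 with this approach, because no flag has been summed yet.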
You can compute a dense_rank based on the Index of the most recent start:
from pyspark.sql import functions as F, Window
df2 = df.withColumn(
'laststart',
F.last(F.when(F.col('start') == 'start', F.col('Index')), ignorenulls=True).over(Window.partitionBy('ID').orderBy('Index'))
).withColumn(
'EventID',
F.dense_rank().over(Window.partitionBy('ID').orderBy('laststart'))
)
df2.show(999)
+---+-----+-----+-----+----+---------+-------+
| ID|Event|Index|start| end|laststart|EventID|
+---+-----+-----+-----+----+---------+-------+
| 1| run| 0|start|null| 0| 1|
| 1| run| 1| null|null| 0| 1|
| 1| run| 2| null|null| 0| 1|
| 1| swim| 3| null| end| 0| 1|
| 1| run| 4|start|null| 4| 2|
| 1| swim| 5| null|null| 4| 2|
| 1| swim| 6| null| end| 4| 2|
| 1| run| 7|start|null| 7| 3|
| 1| run| 8| null|null| 7| 3|
| 1| run| 9| null|null| 7| 3|
| 1| swim| 10| null| end| 7| 3|
| 1| run| 11|start|null| 11| 4|
| 1| run| 12| null|null| 11| 4|
| 1| run| 13| null| end| 11| 4|
| 2| run| 14|start|null| 14| 1|
| 2| run| 15| null|null| 14| 1|
| 2| run| 16| null|null| 14| 1|
| 2| swim| 17| null| end| 14| 1|
| 2| run| 18|start|null| 18| 2|
| 2| swim| 19| null|null| 18| 2|
| 2| swim| 20| null|null| 18| 2|
| 2| swim| 21| null|null| 18| 2|
| 2| swim| 22| null| end| 18| 2|
| 2| run| 23|start|null| 23| 3|
| 2| run| 24| null|null| 23| 3|
| 2| run| 25| null| end| 23| 3|
| 3| run| 26|start|null| 26| 1|
| 3| run| 27| null|null| 26| 1|
| 3| swim| 28| null|null| 26| 1|
+---+-----+-----+-----+----+---------+-------+
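The two steps here (forward-fill the Index of the latest "start", then dense-rank that value within each ID) can be sketched in plain Python to make the mechanics explicit. This is an illustration only, not Spark code: `dense_rank_by_last_start` and `demo_rows` are hypothetical names, and placing missing values first in the ranking is my assumption about how to mirror Spark's default ascending sort.

```python
from itertools import groupby


def dense_rank_by_last_start(rows):
    """Append (laststart, EventID) to each (ID, Event, Index, start, end) tuple."""
    ordered = sorted(rows, key=lambda r: (r[0], r[2]))
    out = []
    for _id, group in groupby(ordered, key=lambda r: r[0]):
        # Step 1: carry forward the Index of the most recent 'start' row
        last_start, filled = None, []
        for row in group:
            if row[3] == 'start':
                last_start = row[2]
            filled.append((row, last_start))
        # Step 2: dense rank = 1-based position among the distinct laststart values
        # (None sorts first here; assumed to match Spark's ascending default)
        distinct = sorted({ls for _, ls in filled},
                          key=lambda v: (v is not None, v or 0))
        rank = {v: i + 1 for i, v in enumerate(distinct)}
        out.extend(row + (ls, rank[ls]) for row, ls in filled)
    return out


demo_rows = [
    (1, "run", 0, 'start', None),
    (1, "swim", 1, None, 'end'),
    (1, "run", 2, 'start', None),
    (2, "run", 3, 'start', None),
    (2, "swim", 4, None, None),
]
ranked = [(r[-2], r[-1]) for r in dense_rank_by_last_start(demo_rows)]
print(ranked)
```

One difference from a running sum is worth keeping in mind: if an ID had rows before its first "start", their laststart would be null, and with nulls ranked first those rows would take rank 1 and push the first real event to 2. The sample data always begins each ID with a "start", so the results agree here.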