如何将条件计数(带重置)应用于 PySpark 中的分组数据?

How to apply conditional counts (with reset) to grouped data in PySpark?

我有 PySpark 代码,可以有效地按数字对行进行分组,并在满足特定条件时递增。我无法弄清楚如何有效地将此代码转换为可应用于组的代码。

取这个样本数据帧 df

df = sqlContext.createDataFrame(
    [
        (33, [], '2017-01-01'),
        (33, ['apple', 'orange'], '2017-01-02'),
        (33, [], '2017-01-03'),
        (33, ['banana'], '2017-01-04')
    ],
    ('ID', 'X', 'date')
)

这段代码实现了我想要的示例 df,即按日期排序并创建在大小列返回 0 时递增的组 ('grp')。

df \
.withColumn('size', size(col('X'))) \
.withColumn(
    "grp", 
    sum((col('size') == 0).cast("int")).over(Window.orderBy('date'))
).show()

部分基于

现在我要做的是将相同的方法应用于具有多个 ID 的数据框 - 实现看起来像

的结果
df2 = sqlContext.createDataFrame(
    [
        (33, [], '2017-01-01', 0, 1),
        (33, ['apple', 'orange'], '2017-01-02', 2, 1),
        (33, [], '2017-01-03', 0, 2),
        (33, ['banana'], '2017-01-04', 1, 2),
        (55, ['coffee'], '2017-01-01', 1, 1),
        (55, [], '2017-01-03', 0, 2)
    ],
    ('ID', 'X', 'date', 'size', 'group')
)

编辑清楚

1) 对于每个 ID 的第一个日期 - 组应该是 1 - 无论任何其他列中显示什么。

2) 但是,对于每个后续日期,我都需要检查尺寸列。如果大小列为 0,则我增加组号。如果是任何非零的正整数,那我就继续前面的组号。

我在 pandas 中看到了一些处理此问题的方法,但我很难理解 pyspark 中的应用程序以及分组数据在 pandas 与 spark 中的不同方式(例如,我需要使用称为 UADF 的东西吗?)

我添加了一个window函数,并在每个ID中创建了一个索引。然后我扩展了条件语句以也引用该索引。以下似乎产生了我想要的输出数据帧 - 但我想知道是否有更有效的方法来做到这一点。

window = Window.partitionBy('ID').orderBy('date')
df \
.withColumn('size', size(col('X'))) \
.withColumn('index', rank().over(window).alias('index')) \
.withColumn(
    "grp", 
    sum(((col('size') == 0) | (col('index') == 1)).cast("int")).over(window)
).show()

产生

+---+---------------+----------+----+-----+---+
| ID|              X|      date|size|index|grp|
+---+---------------+----------+----+-----+---+
| 33|             []|2017-01-01|   0|    1|  1|
| 33|[apple, orange]|2017-01-02|   2|    2|  1|
| 33|             []|2017-01-03|   0|    3|  2|
| 33|       [banana]|2017-01-04|   1|    4|  2|
| 55|       [coffee]|2017-01-01|   1|    1|  1|
| 55|             []|2017-01-03|   0|    2|  2|
+---+---------------+----------+----+-----+---+

通过检查 size 是否为零或该行是否为第一行来创建列 zero_or_first。然后sum.

df2 = sqlContext.createDataFrame(
    [
        (33, [], '2017-01-01', 0, 1),
        (33, ['apple', 'orange'], '2017-01-02', 2, 1),
        (33, [], '2017-01-03', 0, 2),
        (33, ['banana'], '2017-01-04', 1, 2),
        (55, ['coffee'], '2017-01-01', 1, 1),
        (55, [], '2017-01-03', 0, 2),
        (55, ['banana'], '2017-01-01', 1, 1)
    ],
    ('ID', 'X', 'date', 'size', 'group')
)


w = Window.partitionBy('ID').orderBy('date')
df2 = df2.withColumn('row', F.row_number().over(w))
df2 = df2.withColumn('zero_or_first', F.when((F.col('size')==0)|(F.col('row')==1), 1).otherwise(0))
df2 = df2.withColumn('grp', F.sum('zero_or_first').over(w))
df2.orderBy('ID').show()

这是'输出。您可以看到该列 group == grp。其中 group 是预期结果。

+---+---------------+----------+----+-----+---+-------------+---+
| ID|              X|      date|size|group|row|zero_or_first|grp|
+---+---------------+----------+----+-----+---+-------------+---+
| 33|             []|2017-01-01|   0|    1|  1|            1|  1|
| 33|       [banana]|2017-01-04|   1|    2|  4|            0|  2|
| 33|[apple, orange]|2017-01-02|   2|    1|  2|            0|  1|
| 33|             []|2017-01-03|   0|    2|  3|            1|  2|
| 55|       [coffee]|2017-01-01|   1|    1|  1|            1|  1|
| 55|       [banana]|2017-01-01|   1|    1|  2|            0|  1|
| 55|             []|2017-01-03|   0|    2|  3|            1|  2|
+---+---------------+----------+----+-----+---+-------------+---+