How to distribute N rows into X groups and attribute a value D in PySpark?

What I want to do:

In PySpark, I am trying to distribute N rows into X groups of the same size and to assign a value D that is specific to each of these groups.


What I have (df1):

df1 = spark.createDataFrame([ ('1234','banana','Paris'),
                            ('1235','orange','Berlin'),
                            ('1236','orange','Paris'),
                            ('1237','banana','Berlin'),
                            ('1238','orange','Paris'),
                            ('1239','banana','Berlin'),
                       ], ["A","B","C"])

+----+------+------+
|   A|     B|     C|
+----+------+------+
|1234|banana| Paris|
|1235|orange|Berlin|
|1236|orange| Paris|
|1237|banana|Berlin|
|1238|orange| Paris|
|1239|banana|Berlin|
+----+------+------+

What I want (df2):

For example, with X = 3:

    +----+------+------+-----+
    |   A|     B|     C|    D|
    +----+------+------+-----+
    |1234|banana| Paris|date1|
    |1235|orange|Berlin|date1|
    |1236|orange| Paris|date2|
    |1237|banana|Berlin|date3|
    |1238|orange| Paris|date2|
    |1239|banana|Berlin|date3|
    +----+------+------+-----+

For example, with X = 4:

    +----+------+------+-----+
    |   A|     B|     C|    D|
    +----+------+------+-----+
    |1234|banana| Paris|date1|
    |1235|orange|Berlin|date4|
    |1236|orange| Paris|date2|
    |1237|banana|Berlin|date3|
    |1238|orange| Paris|date2|
    |1239|banana|Berlin|date3|
    +----+------+------+-----+


For example, with X = 5:

    +----+------+------+-----+
    |   A|     B|     C|    D|
    +----+------+------+-----+
    |1234|banana| Paris|date1|
    |1235|orange|Berlin|date4|
    |1236|orange| Paris|date2|
    |1237|banana|Berlin|date3|
    |1238|orange| Paris|date2|
    |1239|banana|Berlin|date3|
    +----+------+------+-----+


Note: the ordering of the {B, C} elements can be random.


What I have tried so far:

The following code distributes the elements evenly, but it does not respect the constraint that identical {B, C} combinations must not be split across groups:

>>> from pyspark.sql import Window, functions as F
>>> w = Window.orderBy('B', 'C')
>>> df2 = df1.withColumn("id", F.row_number().over(w) % 3)
>>> df2.show()
+----+------+------+---+
|   A|     B|     C| id|
+----+------+------+---+
|1237|banana|Berlin|  1|
|1239|banana|Berlin|  2|
|1234|banana| Paris|  0|
|1235|orange|Berlin|  1|
|1236|orange| Paris|  2|
|1238|orange| Paris|  0|
+----+------+------+---+


I was advised to use dense_rank instead of row_number: if you take it mod 3, you are not guaranteed groups of the same size, but it will be close, depending on how your data is shuffled. If it needs to be as accurate as possible, you can bucket it as floor(dense_rank_col / max(dense_rank_col) * 3).
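
A minimal sketch of one way to read that suggestion (my own interpretation; X, rank and max_rank are names I introduced, shown here with X = 3):

from pyspark.sql import Window as W, functions as F

X = 3
w = W.orderBy("B", "C")
df2 = (df1
       .withColumn("rank", F.dense_rank().over(w))
       # max over an unpartitioned window gives the number of distinct {B, C} pairs
       .withColumn("max_rank", F.max("rank").over(W.partitionBy()))
       # floor((rank - 1) / max_rank * X) keeps identical {B, C} pairs in the same bucket 0 .. X-1
       .withColumn("id", F.floor((F.col("rank") - 1) / F.col("max_rank") * X))
       .drop("rank", "max_rank"))
df2.show()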

The following alternative answer was also proposed to me:


It relies on collect_list and explode:

df1 = spark.createDataFrame([ ('1234','banana','Paris'),
                            ('1235','orange','Berlin'),
                            ('1236','orange','Paris'),
                            ('1237','banana','Berlin'),
                            ('1238','orange','Paris'),
                            ('1239','banana','Berlin'),
                       ], ["A","B","C"])

from pyspark.sql import Window as W, functions as F

df = df1.groupBy("B", "C").agg(F.collect_list("A").alias("A"))\
        .withColumn("id", F.rand())\
        .withColumn("id", F.row_number().over(W.partitionBy().orderBy("id")) % 3)\
        .withColumn("A", F.explode("A"))\
df.show()

+------+------+----+---+
|     B|     C|   A| id|
+------+------+----+---+
|banana|Berlin|1237|  1|
|banana|Berlin|1239|  1|
|orange|Berlin|1235|  2|
|orange| Paris|1236|  0|
|orange| Paris|1238|  0|
|banana| Paris|1234|  1|
+------+------+----+---+

The result is exactly the same as the answer provided by PySpark Helper.
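
If the D column is meant to hold labels such as date1, date2, ..., one further step (my own addition, not part of the answers above) could derive D from the numeric id:

# hypothetical mapping from the group id 0 .. X-1 to labels date1 .. dateX
df2 = df.withColumn("D", F.concat(F.lit("date"), (F.col("id") + 1).cast("string"))) \
        .drop("id")
df2.show()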