How to intersect rows containing an array for a dataframe in pyspark

I have a dataframe:

df = spark.createDataFrame(
    [(2022, 1, 3, '01', ['apple', 'banana', 'orange'],
      [['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
       ['source', 'Vitamin C', 'fruit']], [['fruit', 2], ['Vitamin', 2]]),
     (2022, 1, 3, '02', ['apple', 'banana', 'avocado'],
      [['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
       ['medium', 'dark', 'green', 'fruit']], [['fruit', 3], ['green', 2]]),
     (2022, 2, 4, '03', ['pomelo', 'fig'],
      [['citrus', 'fruit', 'sweet'], ['soft', 'sweet']], [['sweet', 2]])],
    ['year', 'month', 'day', 'id', 'list_of_fruits',
     'collected_tokens', 'most_common_word']
)

+----+-----+---+---+------------------------+------------------------------------------------------------------------------------------------------------------------+--------------------------+
|year|month|day|id |list_of_fruits          |collected_tokens                                                                                                        |most_common_word          |
+----+-----+---+---+------------------------+------------------------------------------------------------------------------------------------------------------------+--------------------------+
|2022|1    |3  |01 |[apple, banana, orange] |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit], [source, Vitamin C, fruit]]  |[[fruit, 2], [Vitamin, 2]]|
|2022|1    |3  |02 |[apple, banana, avocado]|[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit], [medium, dark, green, fruit]]|[[fruit, 3], [green, 2]]  |
|2022|2    |4  |03 |[pomelo, fig]           |[[citrus, fruit, sweet], [soft, sweet]]                                                                                 |[[sweet, 2]]              |
+----+-----+---+---+------------------------+------------------------------------------------------------------------------------------------------------------------+--------------------------+

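I want to group by year, month and day and intersect the rows on the last three columns, which hold a list, a list of lists, and a list of [key, value] pairs (for the last one, the minimum value of each common key should be kept). In the end I want a result like this: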

+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+
|year|month|day|id |intersection_list_of_fruits|intersection_collected_tokens                                                             |intersection_most_common_word|
+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+
|2022|1    |3  |01 |[apple, banana]            |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]]                 |
|2022|1    |3  |02 |[apple, banana]            |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]]                 |
|2022|2    |4  |03 |[pomelo, fig]              |[[citrus, fruit, sweet], [soft, sweet]]                                                   |[[sweet, 2]]                 |
+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+

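So [orange] and [avocado] are dropped from intersection_list_of_fruits, [source, Vitamin C, fruit] and [medium, dark, green, fruit] are dropped from intersection_collected_tokens, and [Vitamin, 2] and [green, 2] are dropped from intersection_most_common_word, while fruit keeps the smaller of its two counts (2 rather than 3).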

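I know about array_intersect, but I need the intersection across rows, so an aggregation is also needed because of the groupby: the ids that share the same date should be grouped and their values intersected. (I think this could be done with Spark's applyInPandas function.)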
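For reference, here is a plain-Python sketch of the per-date logic described above, i.e. the kind of function one might run once per group (for example through `applyInPandas`). It is only an illustration: the row dicts and the helper name `intersect_per_date` are hypothetical, and it handles a single flat list column. Each group's intersection is copied back onto every `id` of that date, which is why the expected output still has one row per id.

```python
from itertools import groupby
from operator import itemgetter


def intersect_per_date(rows, col):
    """Group rows by (year, month, day), intersect `col` across each group,
    and return one output row per id (hypothetical helper, plain Python)."""
    key = itemgetter("year", "month", "day")
    out = []
    # itertools.groupby only groups adjacent items, so sort by the key first
    for _, grp in groupby(sorted(rows, key=key), key=key):
        grp = list(grp)
        # Keep the elements of the first row that appear in every other row
        common = [x for x in grp[0][col] if all(x in r[col] for r in grp[1:])]
        # Broadcast the group's intersection back onto every id of that date
        for r in grp:
            out.append({"year": r["year"], "month": r["month"], "day": r["day"],
                        "id": r["id"], "intersection_" + col: common})
    return out


rows = [
    {"year": 2022, "month": 1, "day": 3, "id": "01", "list_of_fruits": ["apple", "banana", "orange"]},
    {"year": 2022, "month": 1, "day": 3, "id": "02", "list_of_fruits": ["apple", "banana", "avocado"]},
    {"year": 2022, "month": 2, "day": 4, "id": "03", "list_of_fruits": ["pomelo", "fig"]},
]
for r in intersect_per_date(rows, "list_of_fruits"):
    print(r["id"], r["intersection_list_of_fruits"])
# 01 ['apple', 'banana']
# 02 ['apple', 'banana']
# 03 ['pomelo', 'fig']
```

A plain groupBy would collapse each date into a single row; keeping one row per id requires broadcasting the result back, as the inner loop does.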

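You can use aggregate and array_intersect together with collect_set to compute the intersection of list_of_fruits and of collected_tokens, which gives intersection_list_of_fruits and intersection_collected_tokens. Collecting over a window partitioned by year, month and day (instead of using a groupBy) keeps one output row per id.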
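To see what the `aggregate` + `array_intersect` fold computes, here is a plain-Python stand-in (no Spark involved; `py_array_intersect` and `intersect_all` are hypothetical helpers, not PySpark functions). Like Spark's `array_intersect`, the helper drops duplicates, and because it compares whole elements it also works for the list-of-lists column `collected_tokens`.

```python
from functools import reduce


def py_array_intersect(a, b):
    """Plain-Python stand-in for Spark's array_intersect: the elements of `a`
    that also occur in `b`, without duplicates (hypothetical helper)."""
    result = []
    for x in a:
        # Elements may themselves be lists, so use a list (not a set) for membership
        if x in b and x not in result:
            result.append(x)
    return result


def intersect_all(lists):
    # Same fold as F.aggregate(lists, lists[0], lambda acc, x: F.array_intersect(acc, x))
    return reduce(py_array_intersect, lists, lists[0])


fruits = [["apple", "banana", "orange"], ["apple", "banana", "avocado"]]
tokens = [
    [["apple", "edible", "fruit", "green"],
     ["largest", "herbaceous", "flowering", "plant", "Vitamin B", "fruit"],
     ["source", "Vitamin C", "fruit"]],
    [["apple", "edible", "fruit", "green"],
     ["largest", "herbaceous", "flowering", "plant", "Vitamin B", "fruit"],
     ["medium", "dark", "green", "fruit"]],
]
print(intersect_all(fruits))  # ['apple', 'banana']
print(len(intersect_all(tokens)))  # 2
```

Starting the fold from `lists[0]` is safe because intersecting an array with itself leaves its (deduplicated) contents unchanged.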

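However, intersection_most_common_word also has to take the word counts into account. To do that:

  1. Find the intersection of the words, ignoring the counts.
  2. For each word in the intersection, iterate over the collected most_common_word arrays and take the minimum count.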
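The two steps can be checked on the 2022-01-03 group in plain Python first (a sketch; `min_count_intersection` is a hypothetical helper, not part of PySpark). Converting the counts with `int` matters: compared as strings, `'10'` would sort before `'9'`.

```python
def min_count_intersection(rows):
    """rows: the most_common_word value of each id in one date group,
    e.g. [['fruit', 2], ['Vitamin', 2]]. Hypothetical helper, plain Python."""
    # One {word: count} dict per row; int() makes the comparison numeric
    maps = [{word: int(count) for word, count in row} for row in rows]
    # Step 1: words present in every row, ignoring the counts
    common = [w for w in maps[0] if all(w in m for m in maps[1:])]
    # Step 2: for each common word, keep the smallest count across the rows
    return [[w, min(m[w] for m in maps)] for w in common]


print(min_count_intersection([[["fruit", 2], ["Vitamin", 2]], [["fruit", 3], ["green", 2]]]))
# [['fruit', 2]]
print(min_count_intersection([[["sweet", 2]]]))
# [['sweet', 2]]
```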

from pyspark.sql import functions as F
from pyspark.sql import Window as W
from pyspark.sql import Column
from pyspark.sql.window import WindowSpec

df = spark.createDataFrame(
    [(2022, 1, 3, '01', ['apple', 'banana', 'orange'],
      [['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
       ['source', 'Vitamin C', 'fruit']], [['fruit', 2], ['Vitamin', 2]]),
     (2022, 1, 3, '02', ['apple', 'banana', 'avocado'],
      [['apple', 'edible', 'fruit', 'green'], ['largest', 'herbaceous', 'flowering', 'plant', 'Vitamin B', 'fruit'],
       ['medium', 'dark', 'green', 'fruit']], [['fruit', 3], ['green', 2]]),
     (2022, 2, 4, '03', ['pomelo', 'fig'],
      [['citrus', 'fruit', 'sweet'], ['soft', 'sweet']], [['sweet', 2]])],
    ['year', 'month', 'day', 'id', 'list_of_fruits',
     'collected_tokens', 'most_common_word']
)


def intersection_expr(col_name: str, window_spec: WindowSpec) -> Column:
    # Collect every row's array in the window, then fold them with array_intersect
    lists = F.collect_set(col_name).over(window_spec)
    return F.aggregate(lists, lists[0], lambda acc, x: F.array_intersect(acc, x))


def intersect_min(col_name: str, window_spec: WindowSpec) -> Column:
    # Convert each [word, count] array into a map of word -> count
    # (counts are cast to int so the minimum is numeric, not lexicographic)
    k = F.transform(F.col(col_name), lambda x: x[0])
    v = F.transform(F.col(col_name), lambda x: x[1].cast("int"))
    map_count = F.map_from_arrays(k, v)
    # Collect the maps of all rows in the window into a list
    map_counts = F.collect_list(map_count).over(window_spec)

    # Find the keys present in all maps
    keys = F.transform(map_counts, lambda x: F.map_keys(x))
    intersected = F.aggregate(keys, keys[0], lambda acc, x: F.array_intersect(acc, x))

    # For each intersected key, take the minimum count and rebuild [word, count]
    res = F.transform(
        intersected,
        lambda key: F.array(key, F.array_min(F.transform(map_counts, lambda m: m[key])).cast("string")))

    return res


# One window per date; the unbounded frame lets every row see all rows of its date
window_spec = W.partitionBy("year", "month", "day").orderBy("id").rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

(df.select("year", "month", "day", "id",
           intersection_expr("list_of_fruits", window_spec).alias("intersection_list_of_fruits"),
           intersection_expr("collected_tokens", window_spec).alias("intersection_collected_tokens"),
           intersect_min("most_common_word", window_spec).alias("intersection_most_common_word"))
    .show(truncate=False))


"""
+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+
|year|month|day|id |intersection_list_of_fruits|intersection_collected_tokens                                                             |intersection_most_common_word|
+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+
|2022|1    |3  |01 |[apple, banana]            |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]]                 |
|2022|1    |3  |02 |[apple, banana]            |[[apple, edible, fruit, green], [largest, herbaceous, flowering, plant, Vitamin B, fruit]]|[[fruit, 2]]                 |
|2022|2    |4  |03 |[pomelo, fig]              |[[citrus, fruit, sweet], [soft, sweet]]                                                   |[[sweet, 2]]                 |
+----+-----+---+---+---------------------------+------------------------------------------------------------------------------------------+-----------------------------+
"""