将 UDF 重写为 pandas UDF Pyspark

Rewrite UDF to pandas UDF Pyspark

我有一个数据框:

import pyspark.sql.functions as F

# Sample input: one row per (year, id, bigram); every bigram is a two-word list.
sdf1 = spark.createDataFrame(
    data=[
        (2022, 1, ["apple", "edible"]),
        (2022, 1, ["edible", "fruit"]),
        (2022, 1, ["orange", "sweet"]),
        (2022, 4, ["flowering ", "plant"]),
        (2022, 3, ["green", "kiwi"]),
        (2022, 3, ["kiwi", "fruit"]),
        (2022, 3, ["fruit", "popular"]),
        (2022, 3, ["yellow", "lemon"]),
    ],
    schema=["year", "id", "bigram"],
)
# Print the full cell contents instead of cutting long values off.
sdf1.show(truncate=False)

    +----+---+-------------------+
    |year|id |bigram             |
    +----+---+-------------------+
    |2022|1  |[apple, edible]    |
    |2022|1  |[edible, fruit]    |
    |2022|1  |[orange, sweet]    |
    |2022|4  |[flowering , plant]|
    |2022|3  |[green, kiwi]      |
    |2022|3  |[kiwi, fruit]      |
    |2022|3  |[fruit, popular]   |
    |2022|3  |[yellow, lemon]    |
    +----+---+-------------------+

然后我写了一个函数,它把首尾相接的 bigram(前一个 bigram 的最后一个词与后一个 bigram 的第一个词相同)串联成 n-gram 并返回。我把这个函数单独应用到该列上。

from networkx import DiGraph, dfs_labeled_edges

# Grouping
# One row per (year, id) holding the set of distinct bigrams of that group.
sdf = sdf1.groupby("year", "id").agg(
    F.collect_set("bigram").alias("collect_bigramm")
)
# Record how many distinct bigrams each group contains.
sdf = sdf.withColumn("size", F.size("collect_bigramm"))

# Bring the grouped rows back to the driver as a list of Row objects.
data_collect = sdf.collect()


@udf(returnType=ArrayType(StringType()))
def myfunc(lst):
    """Chain the linked bigrams of one group into a single list of words.

    Two bigrams are linked when the last word of one equals the first word
    of the other; each linked bigram becomes a directed edge
    ``first word -> second word``. The result is the first root-to-leaf path
    found by a depth-first traversal of that graph, with repeated words
    removed while keeping their order.

    The function works only on its argument, so nothing collected on the
    driver has to be shipped to the executors.

    Args:
        lst: The bigrams of one group, each a two-element sequence of words.

    Returns:
        The chained words as a list, or ``None`` when the group has fewer
        than two bigrams or no two bigrams are linked.
    """
    if lst is None or len(lst) < 2:
        # A single bigram has nothing to be chained with.
        return None

    graph = DiGraph()
    # Compare every unordered pair once. The edge that comes first in the
    # chain is inserted first, because node insertion order decides where
    # the depth-first traversal starts.
    for i, lst1 in enumerate(lst):
        for lst2 in lst[i + 1:]:
            if lst1[0] == lst2[1]:
                # lst2 ends where lst1 starts: lst2 precedes lst1.
                graph.add_edge(lst2[0], lst2[1])
                graph.add_edge(lst1[0], lst1[1])
            elif lst1[1] == lst2[0]:
                # lst1 ends where lst2 starts: lst1 precedes lst2.
                graph.add_edge(lst1[0], lst1[1])
                graph.add_edge(lst2[0], lst2[1])

    path = []
    for src, dst, label in dfs_labeled_edges(graph):
        if label == "forward" and src != dst:
            # Real tree edge (the traversal reports each root as a
            # "forward" edge from the node to itself, which is skipped).
            path.extend((src, dst))
        elif label == "reverse" and path:
            # First backtrack after descending: the first path is complete.
            # dict.fromkeys drops repeated words and keeps first-seen order.
            return list(dict.fromkeys(path))
    return None


# Apply the UDF to each group's bigram array and store its result in "new_col".
sdf_new = sdf.withColumn("new_col", myfunc(F.col("collect_bigramm")))
# Print the result without truncating long cells.
sdf_new.show(truncate=False)

输出:

+----+---+-----------------------------------------------------------------+----+-----------------------------+
|year|id |collect_bigramm                                                          |size|new_col                      |
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|2022|4  |[[flowering , plant]]                                            |1   |null                         |
|2022|1  |[[edible, fruit], [orange, sweet], [apple, edible]]              |3   |[apple, edible, fruit]       |
|2022|3  |[[yellow, lemon], [green, kiwi], [kiwi, fruit], [fruit, popular]]|4   |[green, kiwi, fruit, popular]|
+----+---+-----------------------------------------------------------------+----+-----------------------------+

但现在我想改用 pandas UDF。我想先做 groupby,并在函数内部得到 collect_bigramm 列,这样既能保留数据框中的所有列,又能新增一列,即函数中的 lst_res 数组。


# Schema the grouped-map pandas UDF is declared to return; every field and
# every array element is nullable.
schema2 = (
    StructType()
    .add("year", IntegerType(), True)
    .add("id", IntegerType(), True)
    .add("bigram", ArrayType(StringType(), True), True)
    .add("new_col", ArrayType(StringType(), True), True)
    .add("collect_bigramm", ArrayType(ArrayType(StringType(), True), True), True)
)


@pandas_udf(schema2, functionType=PandasUDFType.GROUPED_MAP)
def myfunc(df):
    """Grouped-map pandas UDF skeleton (unfinished).

    Declared to return rows matching ``schema2``. As written, the loop body
    is a placeholder and the input frame is returned unchanged.

    NOTE(review): ``PandasUDFType.GROUPED_MAP`` is deprecated in Spark 3.x in
    favour of ``groupby(...).applyInPandas`` - confirm against the Spark
    version in use.
    """

    graph = DiGraph()
    for index, row in df.iterrows():
        # TODO: the plain UDF worked on its `lst` argument; here the values of
        # the `collect_bigramm` column have to be used instead.
        ...

    return df


# Run myfunc once per (year, id) group of the already-aggregated DataFrame.
sdf_new = sdf.groupby(["year", "id"]).apply(myfunc)
  1. 你不应该运行两次 groupBy(一次在 sdf1 上,一次在 pandas_udf 上),这会破坏 pandas_udf “先对记录分组、再把每组向量化、然后发送给 worker(工作节点)” 的设计思路。你应该这样写:sdf1.groupby("year", "id").applyInPandas(myfunc, schema2)

  2. 你的 UDF 现在是一个 “pandas UDF”,它实际上只是一个普通的 Python 函数:接收一个 pandas DataFrame,返回另一个 pandas DataFrame。也就是说,你甚至可以不依赖 Spark 直接运行这个函数。这里的关键只是如何构造 DataFrame 来满足你的需求。请看下面的代码:我保留了你大部分的 networkx 代码,只对输入和输出部分做了少量修改。

def _chain_bigrams(bigrams):
    """Return the first chain of linked bigrams as a list of words, or None.

    Two bigrams are linked when the last word of one equals the first word
    of the other; each linked bigram becomes a directed edge
    ``first word -> second word``. The chain is the first root-to-leaf path
    found by a depth-first traversal, with repeated words removed while
    keeping their order.
    """
    if len(bigrams) < 2:
        # A single bigram has nothing to be chained with; skip the graph.
        return None

    graph = DiGraph()
    # Compare every unordered pair once. The edge that comes first in the
    # chain is inserted first, because node insertion order decides where
    # the depth-first traversal starts.
    for i, first in enumerate(bigrams):
        for second in bigrams[i + 1:]:
            if first[0] == second[1]:
                # `second` ends where `first` starts: `second` precedes it.
                graph.add_edge(second[0], second[1])
                graph.add_edge(first[0], first[1])
            elif first[1] == second[0]:
                # `first` ends where `second` starts: `first` precedes it.
                graph.add_edge(first[0], first[1])
                graph.add_edge(second[0], second[1])

    path = []
    for src, dst, label in dfs_labeled_edges(graph):
        if label == "forward" and src != dst:
            # Real tree edge (the traversal reports each root as a
            # "forward" edge from the node to itself, which is skipped).
            path.extend((src, dst))
        elif label == "reverse" and path:
            # First backtrack after descending: the first path is complete.
            # dict.fromkeys drops repeated words and keeps first-seen order.
            return list(dict.fromkeys(path))
    return None


def myfunc(pdf):
    """Collapse the bigrams of each (year, id) group and chain them.

    Meant for ``sdf1.groupby("year", "id").applyInPandas(myfunc, schema)``,
    where ``pdf`` holds the rows of one group, but it is a plain pandas
    function and also works on a frame containing several groups.

    Args:
        pdf: pandas DataFrame with the columns ``year``, ``id`` and
            ``bigram`` (one two-word sequence per row).

    Returns:
        A pandas DataFrame with one row per (year, id) and the columns
        ``year``, ``id``, ``collect_bigram``, ``size`` and ``new_col``; the
        schema handed to ``applyInPandas`` has to describe exactly these
        columns. ``new_col`` is the chained word list, or ``None`` when the
        group has fewer than two bigrams or none of them are linked.
    """
    pdf = (
        pdf.groupby(['year', 'id'])['bigram']
        # NOTE: `list` keeps duplicate bigrams; switch to a set-like
        # aggregation if they should be dropped before counting.
        .agg(list=list, len=len)
        .reset_index()
        .rename(columns={
            'list': 'collect_bigram',
            'len': 'size',
        })
    )

    # One result per output row, so a multi-group frame is handled as well.
    pdf['new_col'] = [_chain_bigrams(bigrams) for bigrams in pdf['collect_bigram']]
    return pdf