Rewrite UDF to pandas UDF in PySpark
I have a dataframe:
import pyspark.sql.functions as F
sdf1 = spark.createDataFrame(
[
(2022, 1, ["apple", "edible"]),
(2022, 1, ["edible", "fruit"]),
(2022, 1, ["orange", "sweet"]),
(2022, 4, ["flowering ", "plant"]),
(2022, 3, ["green", "kiwi"]),
(2022, 3, ["kiwi", "fruit"]),
(2022, 3, ["fruit", "popular"]),
(2022, 3, ["yellow", "lemon"]),
],
[
"year",
"id",
"bigram",
],
)
sdf1.show(truncate=False)
+----+---+-------------------+
|year|id |bigram |
+----+---+-------------------+
|2022|1 |[apple, edible] |
|2022|1 |[edible, fruit] |
|2022|1 |[orange, sweet] |
|2022|4 |[flowering , plant]|
|2022|3 |[green, kiwi] |
|2022|3 |[kiwi, fruit] |
|2022|3 |[fruit, popular] |
|2022|3 |[yellow, lemon] |
+----+---+-------------------+
Then I wrote a function that returns chains of bigrams that share words (the last word of one bigram matching the first word of another), e.g. [apple, edible] and [edible, fruit] chain into [apple, edible, fruit]. I apply this function to the column separately.
from networkx import DiGraph, dfs_labeled_edges
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType
# Grouping
sdf = (
sdf1.groupby("year", "id")
.agg(F.collect_set("bigram").alias("collect_bigramm"))
.withColumn("size", F.size("collect_bigramm"))
)
data_collect = sdf.collect()
@udf(returnType=ArrayType(StringType()))
def myfunc(lst):
graph = DiGraph()
for row in data_collect:
if row["size"] > 1:
for i, lst1 in enumerate(lst):
while i < len(lst) - 1:
lst2 = lst[i + 1]
if lst1[0] == lst2[1]:
graph.add_edge(lst2[0], lst2[1])
graph.add_edge(lst1[0], lst1[1])
elif lst1[1] == lst2[0]:
graph.add_edge(lst1[0], lst1[1])
graph.add_edge(lst2[0], lst2[1])
i = i + 1
gen = dfs_labeled_edges(graph)
lst_tmp = []
lst_res = []
f = 0
for g in list(gen):
if (g[2] == "forward") and (g[0] != g[1]):
f = 1
lst_tmp.append(g[0])
lst_tmp.append(g[1])
if g[2] == "nontree":
continue
if g[2] == "reverse":
if f == 1:
lst_res.append(lst_tmp.copy())
f = 0
if g[0] in lst_tmp:
lst_tmp.remove(g[0])
if g[1] in lst_tmp:
lst_tmp.remove(g[1])
if lst_res != []:
lst_res = [
ii for n, ii in enumerate(lst_res[0]) if ii not in lst_res[0][:n]
]
if lst_res == []:
lst_res = None
return lst_res
sdf_new = sdf.withColumn("new_col", myfunc(F.col("collect_bigramm")))
sdf_new.show(truncate=False)
Output:
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|year|id |collect_bigramm |size|new_col |
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|2022|4 |[[flowering , plant]] |1 |null |
|2022|1 |[[edible, fruit], [orange, sweet], [apple, edible]] |3 |[apple, edible, fruit] |
|2022|3 |[[yellow, lemon], [green, kiwi], [kiwi, fruit], [fruit, popular]]|4 |[green, kiwi, fruit, popular]|
+----+---+-----------------------------------------------------------------+----+-----------------------------+
But now I want to use a pandas UDF. I want to do the groupby first and get the collect_bigramm column inside the function, thereby keeping all the columns of the dataframe while also adding a new column with the lst_res array from the function.
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import IntegerType, StructField, StructType

schema2 = StructType(
[
StructField("year", IntegerType(), True),
StructField("id", IntegerType(), True),
StructField("bigram", ArrayType(StringType(), True), True),
StructField("new_col", ArrayType(StringType(), True), True),
StructField("collect_bigramm", ArrayType(ArrayType(StringType(), True), True), True),
]
)
@pandas_udf(schema2, functionType=PandasUDFType.GROUPED_MAP)
def myfunc(df):
graph = DiGraph()
for index, row in df.iterrows():
# Instead of the variable lst, I need to use the column sdf['collect_bigramm'] here
...
return df
sdf_new = sdf.groupby(["year", "id"]).apply(myfunc)
You don't want to run groupBy twice (once on sdf1 and once inside the pandas_udf); that would defeat the whole idea of pandas_udf, which is to "group a list of records, vectorize them, and ship them to the workers". You will want to do something like sdf1.groupby("year", "id").applyInPandas(myfunc, schema2).
Your UDF is now a "pandas UDF", which is really just a Python function that takes a pandas DataFrame and returns another pandas DataFrame. With that in mind, you can even run the function without Spark. The trick here is just how to build the dataframe to satisfy your needs. Check the running code below; I kept most of your networkx code and only fixed the input and output a little.
def myfunc(pdf):
    # Each call receives the rows of one (year, id) group, so the groupby
    # below collapses them into a single aggregated row.
    pdf = (pdf
           .groupby(['year', 'id'])['bigram']
           .agg(list=list, len=len)  # the original used collect_set; swap list for set if you need deduplication
           .reset_index()
           .rename(columns={
               'list': 'collect_bigram',
               'len': 'size',
           })
           )
    graph = DiGraph()
    if pdf['size'][0] > 1:
        lst = pdf['collect_bigram'][0]
        for i, lst1 in enumerate(lst):
            ...  # same as the original networkx code
        if lst_res == []:
            lst_res = None
        pdf['new_col'] = [lst_res]
    else:
        pdf['new_col'] = None
    return pdf
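For completeness, here is a minimal sketch of the wiring and of a local test, assuming the elided networkx block is filled in from the original UDF. Note that the DataFrame myfunc returns has collect_bigram and size columns instead of bigram, so the schema passed to applyInPandas must be adjusted to match; the schema3 below is my assumption, not part of the original answer:

from pyspark.sql.types import (
    ArrayType, IntegerType, LongType, StringType, StructField, StructType
)

# Schema matching the columns myfunc actually returns: 'bigram' is
# aggregated away into 'collect_bigram', and 'size' (a pandas int64,
# hence LongType) is added.
schema3 = StructType(
    [
        StructField("year", IntegerType(), True),
        StructField("id", IntegerType(), True),
        StructField("collect_bigram", ArrayType(ArrayType(StringType(), True), True), True),
        StructField("size", LongType(), True),
        StructField("new_col", ArrayType(StringType(), True), True),
    ]
)

# Group the raw sdf1 once and let Spark hand each (year, id) group to myfunc.
sdf_new = sdf1.groupby("year", "id").applyInPandas(myfunc, schema3)
sdf_new.show(truncate=False)

# Because myfunc is plain pandas, it can also be tested without Spark:
import pandas as pd

pdf_test = pd.DataFrame(
    {
        "year": [2022] * 3,
        "id": [1] * 3,
        "bigram": [["apple", "edible"], ["edible", "fruit"], ["orange", "sweet"]],
    }
)
print(myfunc(pdf_test))

The local call should produce one row whose new_col is [apple, edible, fruit], matching the original UDF's output for id 1.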