将数据框列和外部列表传递给 withColumn 下的 udf

Passing a data frame column and external list to udf under withColumn

我有一个具有以下结构的 Spark 数据框。 bodyText_token 有标记(processed/set 个单词)。我有一个已定义关键字的嵌套列表

root
 |-- id: string (nullable = true)
 |-- body: string (nullable = true)
 |-- bodyText_token: array (nullable = true)

keyword_list=[['union','workers','strike','pay','rally','free','immigration',],
['farmer','plants','fruits','workers'],['outside','field','party','clothes','fashions']]

我需要检查每个关键字列表下有多少标记并将结果添加为现有数据框的新列。 例如:如果 tokens =["become", "farmer","rally","workers","student"] 结果将是 -> [1,2,0]

以下功能按预期工作。

def label_maker_topic(tokens,topic_words):
    twt_list = []
    for i in range(0, len(topic_words)):
        count = 0
        #print(topic_words[i])
        for tkn in tokens:
            if tkn in topic_words[i]:
                count += 1
        twt_list.append(count)
    
    return twt_list

我在withColumn下使用udf访问函数,但出现错误。我认为这是关于将外部列表传递给 udf。有没有一种方法可以将外部列表和数据框列传递给 udf 并向我的数据框添加一个新列?

topicWord = udf(label_maker_topic,StringType())
myDF=myDF.withColumn("topic_word_count",topicWord(myDF.bodyText_token,keyword_list))

最干净的解决方案是使用闭包传递额外的参数:

def make_topic_word(topic_words):
     return udf(lambda c: label_maker_topic(c, topic_words))

df = sc.parallelize([(["union"], )]).toDF(["tokens"])

(df.withColumn("topics", make_topic_word(keyword_list)(col("tokens")))
    .show())

这不需要对 keyword_list 或您用 UDF 包装的函数进行任何更改。您还可以使用此方法传递任意对象。这可用于传递例如 sets 的列表以进行高效查找。

如果您想使用当前的 UDF 并直接传递 topic_words,您必须先将其转换为列文字:

from pyspark.sql.functions import array, lit

ks_lit = array(*[array(*[lit(k) for k in ks]) for ks in keyword_list])
df.withColumn("ad", topicWord(col("tokens"), ks_lit)).show()

根据您的数据和要求,可以选择更高效的解决方案,这些解决方案不需要 UDF(分解 + 聚合 + 折叠)或查找(散列 + 向量运算)。

以下工作正常,任何外部参数都可以传递给 UDF(经过调整的代码可以帮助任何人)

topicWord=udf(lambda tkn: label_maker_topic(tkn,topic_words),StringType())
myDF=myDF.withColumn("topic_word_count",topicWord(myDF.bodyText_token))

使用 functools 模块中的 partial 的另一种方式

from functools import partial

func_to_call = partial(label_maker_topic, topic_words=keyword_list)

pyspark_udf = udf(func_to_call, <specify_the_type_returned_by_function_here>)

df = sc.parallelize([(["union"], )]).toDF(["tokens"])

df.withColumn("topics", pyspark_udf(col("tokens"))).show()

如果列表很大,keyword_list列表应该广播到集群中的所有节点。我猜 zero 的解决方案可行,因为列表很小而且是 auto-broadcasted。在我看来,最好明确广播以消除疑虑(更大的列表需要明确广播)。

keyword_list=[
    ['union','workers','strike','pay','rally','free','immigration',],
    ['farmer','plants','fruits','workers'],
    ['outside','field','party','clothes','fashions']]

def label_maker_topic(tokens, topic_words_broadcasted):
    twt_list = []
    for i in range(0, len(topic_words_broadcasted.value)):
        count = 0
        #print(topic_words[i])
        for tkn in tokens:
            if tkn in topic_words_broadcasted.value[i]:
                count += 1
        twt_list.append(count)

    return twt_list

def make_topic_word_better(topic_words_broadcasted):
    def f(c):
        return label_maker_topic(c, topic_words_broadcasted)
    return F.udf(f)

df = spark.createDataFrame([["union",], ["party",]]).toDF("tokens")
b = spark.sparkContext.broadcast(keyword_list)
df.withColumn("topics", make_topic_word_better(b)(F.col("tokens"))).show()

输出结果如下:

+------+---------+
|tokens|   topics|
+------+---------+
| union|[0, 0, 0]|
| party|[0, 0, 0]|
+------+---------+

请注意,您需要调用 value 才能访问已广播的列表(例如 topic_words_broadcasted.value)。这是一个困难的实现,但掌握起来很重要,因为很多 PySpark UDF 都依赖于已广播的列表或字典。