如何将多个参数传递给 PySpark 中的 Pandas UDF?

How do I pass multiple arguments to a Pandas UDF in PySpark?

我正在处理以下片段:

from cape_privacy.pandas.transformations import Tokenizer

max_token_len = 5


@pandas_udf("string")

def Tokenize(column: pd.Series)-> pd.Series:
  tokenizer = Tokenizer(max_token_len)
  return tokenizer(column)


spark_df = spark_df.withColumn("name", Tokenize("name"))

由于 Pandas UDF 仅使用 Pandas 系列,因此我无法在函数调用 Tokenize("name").

中传递 max_token_len 参数

因此我必须在函数范围之外定义 max_token_len 参数。

this question 中提供的变通办法并不是很有帮助。 对于此问题是否有任何其他可能的解决方法或替代方法?

请指教

您可以通过使用 partial 并在您的 UDF 签名

中直接指定额外的 参数 来实现此目的

数据准备

input_list = [
               (1,None,111)    
               ,(1,None,120)
              ,(1,None,121)
              ,(1,None,124)
              ,(1,'p1',125)
              ,(1,None,126)
              ,(1,None,146)
              ,(1,None,147)
             ]

sparkDF = sql.createDataFrame(input_list,['id','p_id','timestamp'])

sparkDF.show()

+---+----+---------+
| id|p_id|timestamp|
+---+----+---------+
|  1|null|      111|
|  1|null|      120|
|  1|null|      121|
|  1|null|      124|
|  1|  p1|      125|
|  1|null|      126|
|  1|null|      146|
|  1|null|      147|
+---+----+---------+

部分


def add_constant(inp,cnst=5):
    return inp + cnst


cnst_add = 10

partial_func = partial(add_constant,cnst=cnst_add)

sparkDF = sparkDF.withColumn('Constant',partial_func(F.col('timestamp')))
                 
sparkDF.show()

+---+----+---------+----------------+
| id|p_id|timestamp|Constant_Partial|
+---+----+---------+----------------+
|  1|null|      111|             121|
|  1|null|      120|             130|
|  1|null|      121|             131|
|  1|null|      124|             134|
|  1|  p1|      125|             135|
|  1|null|      126|             136|
|  1|null|      146|             156|
|  1|null|      147|             157|
+---+----+---------+----------------+

UDF 签名

cnst_add = 10

add_constant_udf = F.udf(lambda x : add_constant(x,cnst_add),IntegerType())


sparkDF = sparkDF.withColumn('Constant_UDF',add_constant_udf(F.col('timestamp')))

sparkDF.show()

+---+----+---------+------------+
| id|p_id|timestamp|Constant_UDF|
+---+----+---------+------------+
|  1|null|      111|         121|
|  1|null|      120|         130|
|  1|null|      121|         131|
|  1|null|      124|         134|
|  1|  p1|      125|         135|
|  1|null|      126|         136|
|  1|null|      146|         156|
|  1|null|      147|         157|
+---+----+---------+------------+

类似地,您可以按如下方式转换函数 -

from functools import partial

max_token_len = 5

def Tokenize(column: pd.Series,max_token_len=10)-> pd.Series:
  tokenizer = Tokenizer(max_token_len)
  return tokenizer(column)

Tokenize_udf = F.udf(lambda x : Tokenize(x,max_token_len),StringType())

Tokenize_partial = partial(Tokenize,max_token_len=max_token_len)

spark_df = spark_df.withColumn("name", Tokenize_udf("name"))
spark_df = spark_df.withColumn("name", Tokenize_partial("name"))

在尝试了无数种方法之后,我找到了一个毫不费力的解决方案,如下图所示:

我创建了一个 wrapper 函数 (Tokenize_wrapper) 来包装 Pandas UDF (Tokenize_udf) 包装函数返回 Pandas UDF 的 函数调用。

def Tokenize_wrapper(column, max_token_len=10):

  @pandas_udf("string")
  def Tokenize_udf(column: pd.Series) -> pd.Series:
    tokenizer = Tokenizer(max_token_len)
    return tokenizer(column)

  return Tokenize_udf(column)



df = df.withColumn("Name", Tokenize_wrapper("Name", max_token_len=5))

使用部分函数(@Vaebhav 的回答)确实使这个问题的实现变得困难。