使用 pandas_udf 和逻辑语句的 Spark 异常错误

Spark exception error using pandas_udf with logical statement

我正在尝试部署一个简单的 if-else 函数,专门使用 pandas_udf。 这是代码:

from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql.types import *
import pandas as pd

@pandas_udf("string", PandasUDFType.SCALAR )
def seq_sum1(col1,col2):
  if col1 + col2 <= 6:
    v = "low"
  elif ((col1 + col2 > 6) & (col1 + col2 <=10)) :
    v = "medium"
  else:
    v = "High"
  return (v)

# Deploy 
df.select("*",seq_sum1('c1','c2').alias('new_col')).show(10)

这会导致错误:

PythonException: An exception was thrown from a UDF: 'ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().', from <command-1220380192863042>, line 13. Full traceback below:

如果我部署相同的代码但使用@udf 而不是@pandas_udf,它会产生预期的结果。 但是,pandas_udf 似乎不起作用。

我知道这种功能可以通过spark中的其他方式实现(case when等),所以这里的重点是我想了解pandas_udf在处理这种逻辑时是如何工作的。

谢谢

UDF 应该采用 pandas 系列和 return pandas 系列,而不是采用和 returning 字符串。

import pandas as pd
import numpy as np
import pyspark.sql.functions as F
import pyspark.sql.types as T

@F.pandas_udf("string", F.PandasUDFType.SCALAR)
def seq_sum1(col1, col2):
    return pd.Series(
        np.where(
            col1 + col2 <= 6, "low",
            np.where(
                (col1 + col2 > 6) & (col1 + col2 <= 10), "medium",
                    "high"
            )
        )
    )

df.select("*", seq_sum1('c1','c2').alias('new_col')).show()
+---+---+-------+
| c1| c2|new_col|
+---+---+-------+
|  1|  2|    low|
|  3|  4| medium|
|  5|  6|   high|
+---+---+-------+

@mck 提供了见解,我最终使用 map 函数来解决它。

@pandas_udf("string", PandasUDFType.SCALAR)
def seq_sum(col1):
  
  # actual function/calculation goes here
  def main(x):
    if x < 6:
      v = "low"
    else:
      v = "high"
    return(v)
  
  # now apply map function, returning a panda series
  result = pd.Series(map(main,col1))
   
  return (result)

df.select("*",seq_sum('column_name').alias('new_col')).show(10)