Databricks 中 UDF 的错误输出

Wrong output out of UDF in Databricks

我想添加一个名为 academics_category 的新列,其中包含值 academic degree 和 no academic degree。 我创建了一个 udf 函数来检查 bildungsstand(教育)是否与学位相匹配。

问题是,输出中的每个值都不是学历。

from pyspark.sql.types import *

def academics_category(academics):
  if academics == "Bachelors":
    return "academic degree"
  elif academics == "Masters":
    return "academic degree"
  else:
    return "no academic degree"
  
academics_udf = udf(academics_category,StringType())
dfAdult = dfAdult.withColumn('academics_category',academics_udf(dfAdult['bildungsstand']))


bildung = dfAdult.groupBy('bildungsstand','bildungslevel').count().sort('bildungslevel').show(20)

+-------------+-------------+-----+
|bildungsstand|bildungslevel|count|
+-------------+-------------+-----+
|    Preschool|          1.0|   51|
|      1st-4th|          2.0|  168|
|      5th-6th|          3.0|  333|
|      7th-8th|          4.0|  646|
|          9th|          5.0|  514|
|         10th|          6.0|  933|
|         11th|          7.0| 1175|
|         12th|          8.0|  433|
|      HS-grad|          9.0|10501|
| Some-college|         10.0| 7291|
|    Assoc-voc|         11.0| 1382|
|   Assoc-acdm|         12.0| 1067|
|    Bachelors|         13.0| 5355|
|      Masters|         14.0| 1723|
|  Prof-school|         15.0|  576|
|    Doctorate|         16.0|  413|
+-------------+-------------+-----+

UDF 不是最佳解决方案,尤其是对于 Python - 主要是因为需要在 JVM 和 Python 之间发送数据。仅在必要时,建议使用从性能角度来看更好的 Pandas UDFs

但在您的情况下,您可以像这样使用内置 when function

>>> from pyspark.sql.functions import when,col
>>> df = spark.createDataFrame([("Bachelors", 13.0), 
       ("Masters", 14.0), ("Preschool", 1.0)], 
       schema=["bildungsstand", "bildungslevel"])
>>> df2 = df.withColumn("academics_category", 
   when((col("bildungsstand") == "Bachelors") | (col("bildungsstand") == "Masters"), 
      "academic degree").otherwise("no academic degree"))
>>> df2.show()
+-------------+-------------+------------------+
|bildungsstand|bildungslevel|academics_category|
+-------------+-------------+------------------+
|    Bachelors|         13.0|   academic degree|
|      Masters|         14.0|   academic degree|
|    Preschool|          1.0|no academic degree|
+-------------+-------------+------------------+

请注意,您需要将 | 用作 or 运算符,将 & 用作 and 运算符,将 ~ 用作 not运算符

SparkByExamples 有非常 good description of this function

但是如果你真的有固定的值列表,那么使用 isin function 检查值是否在给定值列表中会更容易:

P.S。我推荐 grab this free book - Learning Spark, 2ed - 它会很好地介绍 Spark、它的功能等。

>>> from pyspark.sql.functions import col
>>> df = spark.createDataFrame([("Bachelors", 13.0), ("Masters", 14.0), ("Preschool", 1.0)], schema=["bildungsstand", "bildungslevel"])
>>> df2 = df.withColumn("academics_category", 
  when(col("bildungsstand").isin(["Bachelors","Masters"]), 
    "academic degree").otherwise("no academic degree"))
>>> df2.show()
+-------------+-------------+------------------+
|bildungsstand|bildungslevel|academics_category|
+-------------+-------------+------------------+
|    Bachelors|         13.0|   academic degree|
|      Masters|         14.0|   academic degree|
|    Preschool|          1.0|no academic degree|
+-------------+-------------+------------------+

我在我的案例中发现了问题。根数据集的字符串值中有空格。我用 trim 函数替换了空格并创建了一个新的 DataFrame。

dfAdult = dfAdult.withColumn("bildungsstand_trim",trim(col="bildungsstand"))