PySpark UDF issues when referencing outside of function

I'm getting the error

TypeError: cannot pickle '_thread.RLock' object

when I try to apply the following code:

from pyspark.sql.types import *
from pyspark.sql.functions import *

data_1 = [('James','Smith','M',30),('Anna','Rose','F',41),
  ('Robert','Williams','M',62), 
]
data_2 = [('Junior','Smith','M',15),('Helga','Rose','F',33),
  ('Mike','Williams','M',77), 
]
columns = ["firstname","lastname","gender","age"]
df_1 = spark.createDataFrame(data=data_1, schema = columns)
df_2 = spark.createDataFrame(data=data_2, schema = columns)

def find_n_people_with_higher_age(x):
  return df_2.filter(df_2['age']>=x).count()

find_n_people_with_higher_age_udf = udf(find_n_people_with_higher_age, IntegerType())
df_1.select(find_n_people_with_higher_age_udf(col('age')))

This is a good article on Python UDFs.

I used it as a reference because I suspect you're running into a serialization issue. I'm showing the whole paragraph to give the sentence some context, but the serialization is really the point.

Performance Considerations

It’s important to understand the performance implications of Apache Spark’s UDF features. Python UDFs for example (such as our CTOF function) result in data being serialized between the executor JVM and the Python interpreter running the UDF logic – this significantly reduces performance as compared to UDF implementations in Java or Scala. Potential solutions to alleviate this serialization bottleneck include:

If you think carefully about what you're asking, maybe you can see why it doesn't work. You're asking for all the data in a DataFrame (data_2) to be sent (serialized) to an executor, which would then serialize it and hand it to Python to be interpreted. DataFrames don't serialize, so that's your problem; but even if they did, you'd be shipping an entire DataFrame to every executor. Your sample data here isn't an issue, but with trillions of records it would blow up the JVM.
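As an aside, to make the serialization point concrete: whatever a UDF closes over has to be picklable. A plain Python list (optionally wrapped in a broadcast variable) is; a DataFrame is not. A minimal sketch of that, assuming df_2 is small enough to collect to the driver (the function and column names here are mine, not from the question):

from pyspark.sql.types import IntegerType
from pyspark.sql.functions import udf, col

# Pull the ages out of df_2 into a plain Python list on the driver.
# A list pickles cleanly; the DataFrame it came from does not.
ages = [row.age for row in df_2.select("age").collect()]
ages_bc = spark.sparkContext.broadcast(ages)

# The UDF now closes over the broadcast variable instead of df_2.
def n_people_with_higher_age(x):
    return len([a for a in ages_bc.value if a >= x])

n_people_with_higher_age_udf = udf(n_people_with_higher_age, IntegerType())
df_1.select("firstname", "age",
            n_people_with_higher_age_udf(col("age")).alias("n_older")).show()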

What you're asking is doable; I just had to work out how. A window or a group by might be the trick.

Adding some more data:

from pyspark.sql import Window
from pyspark.sql.types import *
from pyspark.sql.functions import *

data_1 = [('James','Smith','M',30),('Anna','Rose','F',41),
  ('Robert','Williams','M',62),
]
# add more data to make it more interesting.
data_2 = [('Junior','Smith','M',15),('Helga','Rose','F',33),('Gia','Rose','F',34),
  ('Mike','Williams','M',77), ('John','Williams','M',77), ('Bill','Williams','F',79),
]
columns = ["firstname","lastname","gender","age"]
df_1 = spark.createDataFrame(data=data_1, schema=columns)
df_2 = spark.createDataFrame(data=data_2, schema=columns)

# dataframe to help fill in missing ages
ref = spark.range(1, 110, 1).toDF("numbers").withColumn("count", lit(0)).withColumn("rolling_Count", lit(0))

countAges = df_2.groupby("age").count()
# this actually gives you the short list of ages
rollingCounts = countAges.withColumn("rolling_Count", sum(col("count")).over(Window.partitionBy().orderBy(col("age").desc())))
# fill in missing ages and remove duplicates
filled = rollingCounts.union(ref).groupBy("age").agg(sum("count").alias("count"))
# add a rolling count across all ages
allAgeCounts = filled.withColumn("rolling_Count", sum(col("count")).over(Window.partitionBy().orderBy(col("age").desc())))
# do an inner join because we've filled in all ages.
df_1.join(allAgeCounts, df_1.age == allAgeCounts.age, "inner").show()
+---------+--------+------+---+---+-----+-------------+                         
|firstname|lastname|gender|age|age|count|rolling_Count|
+---------+--------+------+---+---+-----+-------------+
|     Anna|    Rose|     F| 41| 41|    0|            3|
|   Robert|Williams|     M| 62| 62|    0|            3|
|    James|   Smith|     M| 30| 30|    0|            5|
+---------+--------+------+---+---+-----+-------------+
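One small cosmetic note (mine, not part of the answer above): the result carries two age columns because the join condition references both DataFrames. Joining on the column name instead collapses them into one:

# Passing the column name uses USING-join semantics, so only a single
# "age" column survives in the result.
df_1.join(allAgeCounts, "age", "inner").show()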

I don't normally like using a window over the whole table, but here the data being iterated over is <= 110 rows, so it's reasonable.
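If that single-partition window ever did become a bottleneck, the same per-person count could also be written as a non-equi join plus a group by. This is only a sketch of an alternative (not the approach above), and a range join like this is effectively a cross join, so it carries its own cost:

from pyspark.sql.functions import col, count

# Rename df_2's age so the join condition is unambiguous.
df_2_ages = df_2.select(col("age").alias("other_age"))

# Range join: match every person in df_1 with every row of df_2 whose age
# is >= theirs, then count the matches. count() skips the nulls produced
# by the left join, so people with no match get 0.
(df_1.join(df_2_ages, col("other_age") >= col("age"), "left")
     .groupBy("firstname", "lastname", "gender", "age")
     .agg(count("other_age").alias("n_older"))
     .show())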