access objects in pyspark user-defined function from outer scope, avoid PicklingError: Could not serialize object

access objects in pyspark user-defined function from outer scope, avoid PicklingError: Could not serialize object

如何避免在 pyspark 用户定义函数中初始化 class?这是一个例子。

正在创建一个 spark 会话和代表四个纬度和经度的 DataFrame。

import pandas as pd
from pyspark import SparkConf
from pyspark.sql import SparkSession

conf = SparkConf()
conf.set('spark.sql.execution.arrow.pyspark.enabled', 'true')
spark = SparkSession.builder.config(conf=conf).getOrCreate()

sdf = spark.createDataFrame(pd.DataFrame({
    'lat': [37, 42, 35, -22],
    'lng': [-113, -107, 127, 34]}))

这是 Spark DataFrame

+---+----+
|lat| lng|
+---+----+
| 37|-113|
| 42|-107|
| 35| 127|
|-22|  34|
+---+----+

通过 timezonefinder 包在每个纬度/经度处使用时区字符串丰富 DataFrame。 下面的代码 运行s 没有错误

from typing import Iterator
from timezonefinder import TimezoneFinder

def func(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for dx in iterator:
        tzf = TimezoneFinder()
        dx['timezone'] = [tzf.timezone_at(lng=a, lat=b) for a, b in zip(dx['lng'], dx['lat'])]
        yield dx
pdf = sdf.mapInPandas(func, schema='lat double, lng double, timezone string').toPandas()

上面的代码 运行 没有错误,并在下面创建了 pandas DataFrame。问题是 TimezoneFinder class 在用户定义的函数中初始化,这会造成瓶颈

In [4]: pdf
Out[4]:
    lat    lng         timezone
0  37.0 -113.0  America/Phoenix
1  42.0 -107.0   America/Denver
2  35.0  127.0       Asia/Seoul
3 -22.0   34.0    Africa/Maputo

问题是如何使此代码更像下面的 运行,其中 TimezoneFinder class 在用户定义的函数之外初始化一次。照原样,下面的代码会生成此错误 PicklingError: Could not serialize object: TypeError: cannot pickle '_io.BufferedReader' object

def func(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for dx in iterator:
        dx['timezone'] = [tzf.timezone_at(lng=a, lat=b) for a, b in zip(dx['lng'], dx['lat'])]
        yield dx
tzf = TimezoneFinder()
pdf = sdf.mapInPandas(func, schema='lat double, lng double, timezone string').toPandas()

更新 - 还尝试使用 functools.partial 和外部函数,但仍然收到相同的错误。也就是这个方法不行:

def outer(iterator, tzf):
    def func(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        for dx in iterator:
            dx['timezone'] = [tzf.timezone_at(lng=a, lat=b) for a, b in zip(dx['lng'], dx['lat'])]
            yield dx
    return func(iterator)
tzf = TimezoneFinder()
outer = partial(outer, tzf=tzf)
pdf = sdf.mapInPandas(outer, schema='lat double, lng double, timezone string').toPandas()

您将需要在每个工作器上缓存该对象的实例。 你可以这样做

instance = [None]

def func(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    if instance[0] is None:
        instance[0] = TimezoneFinder()
    tzf = instance[0]
    for dx in iterator:
        dx['timezone'] = [tzf.timezone_at(lng=a, lat=b) for a, b in zip(dx['lng'], dx['lat'])]
        yield dx

请注意,要使其正常工作,您的函数将在模块中定义,以便为实例缓存提供存储空间。否则你将不得不把它挂在一些内置模块上,例如 os.instance = [].