在 PySpark 中以分布式方式应用 udf 函数

Applying a udf function in a distributed fashion in PySpark

假设我有一个非常基本的 Spark DataFrame,它由几列组成,其中一列包含我要修改的值。

|| value   || lang ||
| 3        |  en   |
| 4        |  ua   |

说,我想为每个特定的 class 创建一个新列,我会在其中为给定值添加一个浮点数(虽然这与最后一个问题关系不大,但实际上我做了一个预测那里有 sklearn,但为了简单起见,我们假设我们正在添加东西,我的想法是我正在以某种方式修改值)。所以给定一个字典 classes={'1':2.0, '2':3.0} 我想为每个 class 有一列,我将 DF 的值添加到 class 的值,然后将其保存到 csv:

class_1.csv
|| value   || lang ||  my_class |  modified  ||
| 3        |  en   |     1      |     5.0    |  # this is 3+2.0
| 4        |  ua   |     1      |     6.0    |  # this is 4+2.0

class_2.csv
|| value   || lang ||  my_class |  modified  ||
| 3        |  en   |     2      |     6.0    |  # this is 3+3.0
| 4        |  ua   |     2      |     7.0    |  # this is 4+3.0

到目前为止,我有以下代码可以工作并修改每个已定义 class 的值,但它是通过 for 循环完成的,我正在为其寻找更高级的优化:

import pyspark
from pyspark import SparkConf, SparkContext
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
from pyspark.sql.functions import udf
from pyspark.sql.functions import lit

# create session and context
spark = pyspark.sql.SparkSession.builder.master("yarn").appName("SomeApp").getOrCreate()
conf = SparkConf().setAppName('Some_App').setMaster("local[*]")
sc = SparkContext.getOrCreate(conf)

my_df = spark.read.csv("some_file.csv")

# modify the value here
def do_stuff_to_column(value, separate_class):
    # do stuff to column, let's pretend we just add a specific value per specific class that is read from a dictionary
    class_dict = {'1':2.0, '2':3.0}  # would be loaded from somewhere
    return float(value+class_dict[separate_class])

 # iterate over each given class later
 class_dict = {'1':2.0, '2':3.0}   # in reality have more than 10 classes

 # create a udf function
 udf_modify = udf(do_stuff_to_column, FloatType())

 # loop over each class
 for my_class in class_dict:
    # create the column first with lit
    my_df2 = my_df.withColumn("my_class", lit(my_class))
    # modify using udf function
    my_df2 = my_df2.withColumn("modified", udf_modify("value","my_class"))
    # write to csv now
    my_df2.write.format("csv").save("class_"+my_class+".csv")

所以问题是,是否有一种 better/faster 方法可以在 for 循环中执行此操作?

我会使用某种形式的 join,在本例中为 crossJoin。这是一个 MWE:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(3, 'en'), (4, 'ua')], ['value', 'lang'])
classes = spark.createDataFrame([(1, 2.), (2, 3.)], ['class_key', 'class_value'])
res = df.crossJoin(classes).withColumn('modified', F.col('value') + F.col('class_value'))
res.show()

对于另存为单独的 CSV,我认为没有比使用循环更好的方法了。