如何传递 Scala UserDefinedFunction,其中输出是复杂类型(使用 StructType 和 StructField)以从 Pyspark 使用

How can I pass a Scala UserDefinedFunction where output is a complex type (using StructType and StructField) to be used from Pyspark

所以,我想创建一个可以在 Pyspark 中使用的 Scala UDF。 我想要的是接受一个字符串列表作为 x 和一个字符串列表作为 y 和 获取所有字符串组合

所以如果我有 x = ["a","b] 和 y=["A","B"] 我希望输出为 out = [[a,A],[a,B ],[b,A],[b,B]]

我成功编写的 Scala 代码很简单

(x: Seq[String], y: Seq[String]) => {for (a <- x; b <-y) yield (a,b)}

我已经创建了一个执行此操作的 scala UDF。它适用于 Scala Spark。

我遇到的问题是试图从 pyspark 调用它。

为了做到这一点,我这样做了:

import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.UserDefinedFunction

import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.api.java.UDF1

class DualArrayExplode extends UDF2[Seq[String], Seq[String], UserDefinedFunction] {
  override  def call(x: Seq[String], y: Seq[String]):UserDefinedFunction = {
    // (worker node stuff)
    
  val DualArrayExplode =  (x: Seq[String], y: Seq[String]) => {for (a <- x; b <-y) yield (a,b)}
  val DualArrayExplodeUDF = (udf(DualArrayExplode))

  return DualArrayExplodeUDF

  }
}

object DualArrayExplode {
  def apply(): DualArrayExplode = {
    new DualArrayExplode()
  }
}

我创建了一个包含此代码和其他函数的 jar(可以正常工作) 此代码编译没有问题。

我在 scala spark 中使用它时的输出列类型是 Array(ArrayType(StructType(StructField(_1,StringType,true), StructField(_2,StringType,true)),true))

我的问题是我无法让它与 Pyspark 一起使用。注册此函数时无法定义正确的 return 类型。

这是我尝试注册 UDF 的方法

spark.udf.registerJavaFunction('DualArrayExplode', 
                               'blah.blah.blah.blah.blah.DualArrayExplode', <WHAT_TYPE_HERE???>)

Return 类型是可选的,但如果我省略它,那么结果是 [](一个空列表)

那么...我怎样才能在 pyspark 中实际使用这个 scala UDF?

我意识到可能会出现很多问题,因此我尽可能地描述了整个设置。

DualArrayExplode的声明是

class DualArrayExplode extends UDF2[Seq[String], Seq[String], UserDefinedFunction]

这意味着声明了一个 udf,它将两个字符串序列作为输入,returns 一个 udf。这个应该改成

class DualArrayExplode extends UDF2[Seq[String], Seq[String], Seq[(String,String)]] {
  override  def call(x: Seq[String], y: Seq[String]): Seq[(String,String)]= {
    // (worker node stuff)
    for (a <- x; b <-y) yield (a,b)
  }
}

udf 的 return 类型已更改为字符串元组序列。

现在可以使用

在 Pyspark 中注册此 udf
from pyspark.sql import types as T
rt = T.ArrayType(T.StructType([T.StructField("_1",T.StringType()), 
                               T.StructField("_2",T.StringType())]))
spark.udf.registerJavaFunction(name='DualArrayExplode', 
            javaClassName='blah.blah.DualArrayExplode', returnType=rt)