Handling models with circular references in Spark SQL?
Yesterday (the whole day, actually) I was trying to figure out an elegant way to represent a model with circular references in Scala/Spark SQL 2.2.1.
Let's say this is the original modeling approach, which of course does not work (keep in mind that the real model has dozens of attributes):
case class Branch(id: Int, branches: List[Branch] = List.empty)
case class Tree(id: Int, branches: List[Branch])
val trees = Seq(Tree(1, List(Branch(2, List.empty), Branch(3, List(Branch(4, List.empty))))))
val ds = spark.createDataset(trees)
ds.show
This is the error it throws:
java.lang.UnsupportedOperationException: cannot have circular references in class, but got the circular reference of class Branch
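Note that the exception is thrown while Spark derives the encoder for the case class, before any data is actually processed. A minimal sketch (assuming the same Branch class as above) that should trigger the same check without even building a Dataset:

import org.apache.spark.sql.Encoders

// Deriving the product encoder alone runs Spark's reflection over the
// case class, so this line should already throw the
// UnsupportedOperationException shown above.
val branchEncoder = Encoders.product[Branch]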
I know the maximum depth in our case is 5, so as a workaround I thought of something like this:
case class BranchLevel5(id: Int)
case class BranchLevel4(id: Int, branches: List[BranchLevel5] = List.empty)
case class BranchLevel3(id: Int, branches: List[BranchLevel4] = List.empty)
case class BranchLevel2(id: Int, branches: List[BranchLevel3] = List.empty)
case class BranchLevel1(id: Int, branches: List[BranchLevel2] = List.empty)
case class Tree(id: Int, branches: List[BranchLevel1])
It works, of course. But it is not at all elegant, and you can imagine the pain of implementing it this way (readability, coupling, maintenance, usability, code duplication, and so on).
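To make the duplication concrete: besides the class definitions themselves, you would also need a near-identical conversion function per level to map the recursive domain model onto the flattened one. A hypothetical sketch (these converters are not part of the original code, and anything nested deeper than level 5 is silently truncated):

def toL5(b: Branch): BranchLevel5 = BranchLevel5(b.id) // drops deeper branches
def toL4(b: Branch): BranchLevel4 = BranchLevel4(b.id, b.branches.map(toL5))
def toL3(b: Branch): BranchLevel3 = BranchLevel3(b.id, b.branches.map(toL4))
def toL2(b: Branch): BranchLevel2 = BranchLevel2(b.id, b.branches.map(toL3))
def toL1(b: Branch): BranchLevel1 = BranchLevel1(b.id, b.branches.map(toL2))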
So the question is: how do you handle cases like this, where the model has circular references?
If you don't mind using a private API, there is one approach I found that works well: treat the whole self-referencing structure as a user-defined type. I followed the approach from this answer.
package org.apache.spark.custom.udts // we're calling a private API, so this needs to live under 'org.apache.spark'
import java.io._
import org.apache.spark.sql.types.{DataType, UDTRegistration, UserDefinedType}
class BranchUDT extends UserDefinedType[Branch] {

  // Store the whole recursive structure as one opaque binary column.
  override def sqlType: DataType = org.apache.spark.sql.types.BinaryType

  // Java-serialize the Branch, including all of its nested sub-branches.
  override def serialize(obj: Branch): Any = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(obj)
    bos.toByteArray
  }

  override def deserialize(datum: Any): Branch = {
    val bis = new ByteArrayInputStream(datum.asInstanceOf[Array[Byte]])
    val ois = new ObjectInputStream(bis)
    val obj = ois.readObject()
    obj.asInstanceOf[Branch]
  }

  override def userClass: Class[Branch] = classOf[Branch]
}

object BranchUDT {
  def register() = UDTRegistration.register(classOf[Branch].getName, classOf[BranchUDT].getName)
}
Just create and register the custom UDT, and that's it!
BranchUDT.register()
val trees = Seq(Tree(1, List(Branch(2, List.empty), Branch(3, List(Branch(4, List.empty))))))
val ds = spark.createDataset(trees)
ds.show(false)
//+---+----------------------------------------------------+
//|id |branches |
//+---+----------------------------------------------------+
//|1 |[Branch(2,List()), Branch(3,List(Branch(4,List())))]|
//+---+----------------------------------------------------+
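One caveat with this approach: because sqlType is BinaryType, the branches column is an opaque blob as far as Catalyst is concerned, so you cannot reach into it with SQL or column expressions, and every access pays the cost of Java serialization. Typed operations that round-trip through the JVM objects should still work, though. A sketch, assuming the Dataset built above:

import spark.implicits._

// Typed transformations deserialize the full objects, so this should work:
val counts = ds.map(t => (t.id, t.branches.size))
counts.show()

// Column expressions cannot decompose the UDT, however: `branches` is not
// an array of structs, so something like ds.select($"branches.id") fails.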