Spark UDAF 泛型类型不匹配
Spark UDAF generics type mismatch
我正尝试在 Spark(2.0.1、Scala 2.11)上创建一个 UDAF,如下所示。这实际上是聚合元组并输出 Map
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}
class mySumToMap[K, V](keyType: DataType, valueType: DataType) extends UserDefinedAggregateFunction {
override def inputSchema = new StructType()
.add("a_key", keyType)
.add("a_value", valueType)
override def bufferSchema = new StructType()
.add("buffer_map", MapType(keyType, valueType))
override def dataType = MapType(keyType, valueType)
override def deterministic = true
override def initialize(buffer: MutableAggregationBuffer) = {
buffer(0) = Map[K, V]()
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// input :: 0 = a_key (k), 1 = a_value
if ( !(input.isNullAt(0)) ) {
val a_map = buffer(0).asInstanceOf[Map[K, V]]
val k = input.getAs[K](0) // get the value of position 0 of the input as string (a_key)
// I've split these on purpose to show that return values are all of type V
val new_v1: V = a_map.getOrElse(k, 0.asInstanceOf[V])
val new_v2: V = input.getAs[V](1)
val new_v: V = new_v1 + new_v2
buffer(0) = if (new_v != 0) a_map + (k -> new_v) else a_map - k
}
}
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
val map1: Map[K, V] = buffer1(0).asInstanceOf[Map[K, V]]
val map2: Map[K, V] = buffer2(0).asInstanceOf[Map[K, V]]
buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k, 0.asInstanceOf[V])) }
}
override def evaluate(buffer: Row) = buffer(0).asInstanceOf[Map[K, V]]
}
但是当我编译这个的时候,我看到了下面的错误:
<console>:74: error: type mismatch;
found : V
required: String
val new_v: V = new_v1 + new_v2
^
<console>:84: error: type mismatch;
found : V
required: String
buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k, 0.asInstanceOf[V])) }
我做错了什么?
编辑: 对于那些将此标记为 重复的人 - 这不是那个问题的重复,因为那个问题不处理 Map 数据类型。上面的代码对于使用 Map 数据类型所面临的问题非常具体和完整。
将类型限制为 Numeric[_]
class mySumToMap[K, V: Numeric](keyType: DataType, valueType: DataType)
extends UserDefinedAggregateFunction {
...
使用 implicitly
在运行时获取它:
val n = implicitly[Numeric[V]]
并使用其 plus
方法代替 +
和 zero
代替 0
buffer1(0) = map1 ++ map2.map{
case (k,v) => k -> n.plus(v, map1.getOrElse(k, n.zero))
}
要支持更广泛的类型集,您可以使用 cats
Monoid
:
import cats._
import cats.implicits._
并调整代码:
class mySumToMap[K, V: Monoid](keyType: DataType, valueType: DataType)
extends UserDefinedAggregateFunction {
...
及以后:
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
val map1: Map[K, V] = buffer1.getMap[K, V](0)
val map2: Map[K, V] = buffer2.getMap[K, V](0)
val m = implicitly[Monoid[Map[K, V]]]
buffer1(0) = m.combine(map1, map2)
}
我正尝试在 Spark(2.0.1、Scala 2.11)上创建一个 UDAF,如下所示。这实际上是聚合元组并输出 Map
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}
class mySumToMap[K, V](keyType: DataType, valueType: DataType) extends UserDefinedAggregateFunction {
override def inputSchema = new StructType()
.add("a_key", keyType)
.add("a_value", valueType)
override def bufferSchema = new StructType()
.add("buffer_map", MapType(keyType, valueType))
override def dataType = MapType(keyType, valueType)
override def deterministic = true
override def initialize(buffer: MutableAggregationBuffer) = {
buffer(0) = Map[K, V]()
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// input :: 0 = a_key (k), 1 = a_value
if ( !(input.isNullAt(0)) ) {
val a_map = buffer(0).asInstanceOf[Map[K, V]]
val k = input.getAs[K](0) // get the value of position 0 of the input as string (a_key)
// I've split these on purpose to show that return values are all of type V
val new_v1: V = a_map.getOrElse(k, 0.asInstanceOf[V])
val new_v2: V = input.getAs[V](1)
val new_v: V = new_v1 + new_v2
buffer(0) = if (new_v != 0) a_map + (k -> new_v) else a_map - k
}
}
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
val map1: Map[K, V] = buffer1(0).asInstanceOf[Map[K, V]]
val map2: Map[K, V] = buffer2(0).asInstanceOf[Map[K, V]]
buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k, 0.asInstanceOf[V])) }
}
override def evaluate(buffer: Row) = buffer(0).asInstanceOf[Map[K, V]]
}
但是当我编译这个的时候,我看到了下面的错误:
<console>:74: error: type mismatch;
found : V
required: String
val new_v: V = new_v1 + new_v2
^
<console>:84: error: type mismatch;
found : V
required: String
buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k, 0.asInstanceOf[V])) }
我做错了什么?
编辑: 对于那些将此标记为 重复的人 - 这不是那个问题的重复,因为那个问题不处理 Map 数据类型。上面的代码对于使用 Map 数据类型所面临的问题非常具体和完整。
将类型限制为 Numeric[_]
class mySumToMap[K, V: Numeric](keyType: DataType, valueType: DataType)
extends UserDefinedAggregateFunction {
...
使用 implicitly
在运行时获取它:
val n = implicitly[Numeric[V]]
并使用其 plus
方法代替 +
和 zero
代替 0
buffer1(0) = map1 ++ map2.map{
case (k,v) => k -> n.plus(v, map1.getOrElse(k, n.zero))
}
要支持更广泛的类型集,您可以使用 cats
Monoid
:
import cats._
import cats.implicits._
并调整代码:
class mySumToMap[K, V: Monoid](keyType: DataType, valueType: DataType)
extends UserDefinedAggregateFunction {
...
及以后:
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
val map1: Map[K, V] = buffer1.getMap[K, V](0)
val map2: Map[K, V] = buffer2.getMap[K, V](0)
val m = implicitly[Monoid[Map[K, V]]]
buffer1(0) = m.combine(map1, map2)
}