Spark collect_list 并限制结果列表
Spark collect_list and limit resulting list
我有以下格式的数据框:
name merged
key1 (internalKey1, value1)
key1 (internalKey2, value2)
...
key2 (internalKey3, value3)
...
我想要做的是按 name
对数据帧进行分组,收集列表并限制 列表的大小。
这就是我按 name
分组并收集列表的方式:
val res = df.groupBy("name")
.agg(collect_list(col("merged")).as("final"))
结果数据框类似于:
key1 [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list
key2 [(internalKey3, value3),...]
我想做的是限制每个键的生成列表的大小。我尝试了多种方法来做到这一点,但没有成功。我已经看到一些建议第 3 方解决方案的帖子,但我想避免这种情况。有办法吗?
您可以创建一个函数来限制聚合 ArrayType 列的大小,如下所示:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column
case class KV(k: String, v: String)
val df = Seq(
("key1", KV("internalKey1", "value1")),
("key1", KV("internalKey2", "value2")),
("key2", KV("internalKey3", "value3")),
("key2", KV("internalKey4", "value4")),
("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")
def limitSize(n: Int, arrCol: Column): Column =
array( (0 until n).map( arrCol.getItem ): _* )
df.
groupBy("name").agg( collect_list(col("merged")).as("final") ).
select( $"name", limitSize(2, $"final").as("final2") ).
show(false)
// +----+----------------------------------------------+
// |name|final2 |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+
因此,虽然 UDF 可以满足您的需求,但如果您正在寻找一种性能更高且对内存敏感的方法,那么实现此目的的方法就是编写 UDAF。不幸的是,UDAF API 实际上不像 spark 附带的聚合函数那样可扩展。但是,您可以使用他们的内部 APIs 构建内部函数来执行您需要的操作。
这里是 collect_list_limit
的一个实现,主要是 Spark 内部 CollectList
AggregateFunction 的副本。我只想扩展它,但它是一个案例 class。实际上,所需要的只是覆盖更新和合并方法以遵守传入的限制:
case class CollectListLimit(
child: Expression,
limitExp: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
val limit = limitExp.eval( null ).asInstanceOf[Int]
def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
if( buffer.size < limit ) super.update(buffer, input)
else buffer
}
override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
if( buffer.size >= limit ) buffer
else if( other.size >= limit ) other
else ( buffer ++= other ).take( limit )
}
override def prettyName: String = "collect_list_limit"
}
要实际注册它,我们可以通过 Spark 的内部 FunctionRegistry
来完成,它接受名称和构建器,它实际上是一个使用提供的表达式创建 CollectListLimit
的函数:
val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )
编辑:
事实证明,仅当您尚未创建 SparkContext 时才将其添加到内置函数中才有效,因为它会在启动时生成不可变的克隆。如果你有一个现有的上下文,那么这应该可以通过反射添加它:
val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )
您可以使用 UDF。
这里是一个可能的例子,它不需要模式并且有一个有意义的减少:
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import scala.collection.mutable
object TestJob1 {
def main (args: Array[String]): Unit = {
val sparkSession = SparkSession
.builder()
.appName(this.getClass.getName.replace("$", ""))
.master("local")
.getOrCreate()
val sc = sparkSession.sparkContext
import sparkSession.sqlContext.implicits._
val rawDf = Seq(
("key", 1L, "gargamel"),
("key", 4L, "pe_gadol"),
("key", 2L, "zaam"),
("key1", 5L, "naval")
).toDF("group", "quality", "other")
rawDf.show(false)
rawDf.printSchema
val rawSchema = rawDf.schema
val fUdf = udf(reduceByQuality, rawSchema)
val aggDf = rawDf
.groupBy("group")
.agg(
count(struct("*")).as("num_reads"),
max(col("quality")).as("quality"),
collect_list(struct("*")).as("horizontal")
)
.withColumn("short", fUdf($"horizontal"))
.drop("horizontal")
aggDf.printSchema
aggDf.show(false)
}
def reduceByQuality= (x: Any) => {
val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
val red = d.reduce((r1, r2) => {
val quality1 = r1.getAs[Long]("quality")
val quality2 = r2.getAs[Long]("quality")
val r3 = quality1 match {
case a if a >= quality2 =>
r1
case _ =>
r2
}
r3
})
red
}
}
这里有一个像你这样的数据的例子
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.functions._
import scala.collection.mutable
object TestJob {
def main (args: Array[String]): Unit = {
val sparkSession = SparkSession
.builder()
.appName(this.getClass.getName.replace("$", ""))
.master("local")
.getOrCreate()
val sc = sparkSession.sparkContext
import sparkSession.sqlContext.implicits._
val df1 = Seq(
("key1", ("internalKey1", "value1")),
("key1", ("internalKey2", "value2")),
("key2", ("internalKey3", "value3")),
("key2", ("internalKey4", "value4")),
("key2", ("internalKey5", "value5"))
)
.toDF("name", "merged")
// df1.printSchema
//
// df1.show(false)
val res = df1
.groupBy("name")
.agg( collect_list(col("merged")).as("final") )
res.printSchema
res.show(false)
def f= (x: Any) => {
val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
val d1 = d.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]].head
d1.toString
}
val fUdf = udf(f, StringType)
val d2 = res
.withColumn("d", fUdf(col("final")))
.drop("final")
d2.printSchema()
d2
.show(false)
}
}
我有以下格式的数据框:
name merged
key1 (internalKey1, value1)
key1 (internalKey2, value2)
...
key2 (internalKey3, value3)
...
我想要做的是按 name
对数据帧进行分组,收集列表并限制 列表的大小。
这就是我按 name
分组并收集列表的方式:
val res = df.groupBy("name")
.agg(collect_list(col("merged")).as("final"))
结果数据框类似于:
key1 [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list
key2 [(internalKey3, value3),...]
我想做的是限制每个键的生成列表的大小。我尝试了多种方法来做到这一点,但没有成功。我已经看到一些建议第 3 方解决方案的帖子,但我想避免这种情况。有办法吗?
您可以创建一个函数来限制聚合 ArrayType 列的大小,如下所示:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column
case class KV(k: String, v: String)
val df = Seq(
("key1", KV("internalKey1", "value1")),
("key1", KV("internalKey2", "value2")),
("key2", KV("internalKey3", "value3")),
("key2", KV("internalKey4", "value4")),
("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")
def limitSize(n: Int, arrCol: Column): Column =
array( (0 until n).map( arrCol.getItem ): _* )
df.
groupBy("name").agg( collect_list(col("merged")).as("final") ).
select( $"name", limitSize(2, $"final").as("final2") ).
show(false)
// +----+----------------------------------------------+
// |name|final2 |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+
因此,虽然 UDF 可以满足您的需求,但如果您正在寻找一种性能更高且对内存敏感的方法,那么实现此目的的方法就是编写 UDAF。不幸的是,UDAF API 实际上不像 spark 附带的聚合函数那样可扩展。但是,您可以使用他们的内部 APIs 构建内部函数来执行您需要的操作。
这里是 collect_list_limit
的一个实现,主要是 Spark 内部 CollectList
AggregateFunction 的副本。我只想扩展它,但它是一个案例 class。实际上,所需要的只是覆盖更新和合并方法以遵守传入的限制:
case class CollectListLimit(
child: Expression,
limitExp: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
val limit = limitExp.eval( null ).asInstanceOf[Int]
def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
if( buffer.size < limit ) super.update(buffer, input)
else buffer
}
override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
if( buffer.size >= limit ) buffer
else if( other.size >= limit ) other
else ( buffer ++= other ).take( limit )
}
override def prettyName: String = "collect_list_limit"
}
要实际注册它,我们可以通过 Spark 的内部 FunctionRegistry
来完成,它接受名称和构建器,它实际上是一个使用提供的表达式创建 CollectListLimit
的函数:
val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )
编辑:
事实证明,仅当您尚未创建 SparkContext 时才将其添加到内置函数中才有效,因为它会在启动时生成不可变的克隆。如果你有一个现有的上下文,那么这应该可以通过反射添加它:
val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )
您可以使用 UDF。
这里是一个可能的例子,它不需要模式并且有一个有意义的减少:
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import scala.collection.mutable
object TestJob1 {
def main (args: Array[String]): Unit = {
val sparkSession = SparkSession
.builder()
.appName(this.getClass.getName.replace("$", ""))
.master("local")
.getOrCreate()
val sc = sparkSession.sparkContext
import sparkSession.sqlContext.implicits._
val rawDf = Seq(
("key", 1L, "gargamel"),
("key", 4L, "pe_gadol"),
("key", 2L, "zaam"),
("key1", 5L, "naval")
).toDF("group", "quality", "other")
rawDf.show(false)
rawDf.printSchema
val rawSchema = rawDf.schema
val fUdf = udf(reduceByQuality, rawSchema)
val aggDf = rawDf
.groupBy("group")
.agg(
count(struct("*")).as("num_reads"),
max(col("quality")).as("quality"),
collect_list(struct("*")).as("horizontal")
)
.withColumn("short", fUdf($"horizontal"))
.drop("horizontal")
aggDf.printSchema
aggDf.show(false)
}
def reduceByQuality= (x: Any) => {
val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
val red = d.reduce((r1, r2) => {
val quality1 = r1.getAs[Long]("quality")
val quality2 = r2.getAs[Long]("quality")
val r3 = quality1 match {
case a if a >= quality2 =>
r1
case _ =>
r2
}
r3
})
red
}
}
这里有一个像你这样的数据的例子
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.functions._
import scala.collection.mutable
object TestJob {
def main (args: Array[String]): Unit = {
val sparkSession = SparkSession
.builder()
.appName(this.getClass.getName.replace("$", ""))
.master("local")
.getOrCreate()
val sc = sparkSession.sparkContext
import sparkSession.sqlContext.implicits._
val df1 = Seq(
("key1", ("internalKey1", "value1")),
("key1", ("internalKey2", "value2")),
("key2", ("internalKey3", "value3")),
("key2", ("internalKey4", "value4")),
("key2", ("internalKey5", "value5"))
)
.toDF("name", "merged")
// df1.printSchema
//
// df1.show(false)
val res = df1
.groupBy("name")
.agg( collect_list(col("merged")).as("final") )
res.printSchema
res.show(false)
def f= (x: Any) => {
val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
val d1 = d.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]].head
d1.toString
}
val fUdf = udf(f, StringType)
val d2 = res
.withColumn("d", fUdf(col("final")))
.drop("final")
d2.printSchema()
d2
.show(false)
}
}