How to define and use a User-Defined Aggregate Function in Spark SQL?

I know how to write a UDF in Spark SQL:

def belowThreshold(power: Int): Boolean = power < -40

sqlContext.udf.register("belowThreshold", belowThreshold _)

Can I do something similar to define an aggregate function? How is this done?

For context, I want to run the following SQL query:

val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
                                    FROM ifDF
                                    WHERE opticalReceivePower IS NOT null
                                    GROUP BY span, timestamp
                                    ORDER BY span""")

It should return something like

Row(span1, false, T0)

I would like the aggregate function to tell me whether there are any values of opticalReceivePower in the groups defined by span and timestamp which are below the threshold. Do I need to write my UDAF differently from the UDF I pasted above?

Supported methods

Spark >= 3.0

Scala UserDefinedAggregateFunction is being deprecated (SPARK-30423 Deprecate UserDefinedAggregateFunction) in favor of registered Aggregator.
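For example (a minimal sketch, assuming Spark 3.0+, an active SparkSession named spark, the BelowThreshold Aggregator shown in the Spark >= 2.0 section below, and a hypothetical temporary view named df):

import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.Encoders

// Wrap the typed Aggregator into an untyped UDAF and register it for SQL use
val belowThresholdUdf = udaf(new BelowThreshold[Int](_ < -40), Encoders.scalaInt)
spark.udf.register("belowThreshold", belowThresholdUdf)

spark.sql("SELECT group, belowThreshold(power) FROM df GROUP BY group")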

Spark >= 2.3

Vectorized udf (Python only):

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import *
import pandas as pd

df = sc.parallelize([
    ("a", 0), ("a", 1), ("b", 30), ("b", -50)
]).toDF(["group", "power"])

def below_threshold(threshold, group="group", power="power"):
    # Grouped map UDF: receives each group's rows as a pandas DataFrame
    # and must return a pandas DataFrame matching the declared schema
    @pandas_udf("struct<group: string, below_threshold: boolean>", PandasUDFType.GROUPED_MAP)
    def below_threshold_(df):
        # True if any value in the power column falls below the threshold
        df = pd.DataFrame(
            df.groupby(group).apply(lambda x: (x[power] < threshold).any()))
        df.reset_index(inplace=True, drop=False)
        return df

    return below_threshold_

Example usage:

df.groupBy("group").apply(below_threshold(-40)).show()

## +-----+---------------+
## |group|below_threshold|
## +-----+---------------+
## |    b|           true|
## |    a|          false|
## +-----+---------------+

See also Applying UDFs on GroupedData in PySpark (with functioning python example)

Spark >= 2.0 (optionally 1.6, but with a slightly different API):

It is possible to use Aggregators on typed Datasets:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

class BelowThreshold[I](f: I => Boolean) extends Aggregator[I, Boolean, Boolean]
    with Serializable {
  // Initial value of the aggregation buffer
  def zero = false
  // Fold a single record into the buffer (similar to seqOp in aggregate)
  def reduce(acc: Boolean, x: I) = acc | f(x)
  // Combine buffers from different partitions (similar to combOp)
  def merge(acc1: Boolean, acc2: Boolean) = acc1 | acc2
  // Transform the final buffer into the output value
  def finish(acc: Boolean) = acc

  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}

val belowThreshold = new BelowThreshold[(String, Int)](_._2 < -40).toColumn
df.as[(String, Int)].groupByKey(_._1).agg(belowThreshold)

Spark >= 1.5:

In Spark 1.5 you can create a UDAF like this, although it is most likely overkill:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

object belowThreshold extends UserDefinedAggregateFunction {
    // Schema you get as an input
    def inputSchema = new StructType().add("power", IntegerType)
    // Schema of the row which is used for aggregation
    def bufferSchema = new StructType().add("ind", BooleanType)
    // Returned type
    def dataType = BooleanType
    // Self-explaining 
    def deterministic = true
    // zero value
    def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, false)
    // Similar to seqOp in aggregate
    def update(buffer: MutableAggregationBuffer, input: Row) = {
        if (!input.isNullAt(0))
          buffer.update(0, buffer.getBoolean(0) | input.getInt(0) < -40)
    }
    // Similar to combOp in aggregate
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1.update(0, buffer1.getBoolean(0) | buffer2.getBoolean(0))    
    }
    // Called on exit to get return value
    def evaluate(buffer: Row) = buffer.getBoolean(0)
}

Example usage:

df
  .groupBy($"group")
  .agg(belowThreshold($"power").alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+
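Since the question asks about running a SQL query, a UserDefinedAggregateFunction can also be registered by name and called from plain SQL. A minimal sketch, assuming the ifDF DataFrame from the question has been registered as a temporary table named ifDF:

sqlContext.udf.register("belowThreshold", belowThreshold)

val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
                              FROM ifDF
                              WHERE opticalReceivePower IS NOT NULL
                              GROUP BY span, timestamp
                              ORDER BY span""")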

Spark 1.4 workaround

I am not sure if I correctly understand your requirements, but as far as I can tell, plain old aggregation should be enough here:

import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.IntegerType

val df = sc.parallelize(Seq(
    ("a", 0), ("a", 1), ("b", 30), ("b", -50))).toDF("group", "power")

df
  .withColumn("belowThreshold", ($"power".lt(-40)).cast(IntegerType))
  .groupBy($"group")
  .agg(sum($"belowThreshold").notEqual(0).alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+

Spark <= 1.4:

As far as I know, at this moment (Spark 1.4.1), there is no support for UDAFs other than the Hive ones. It should be possible with Spark 1.5 (see SPARK-3947).
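If you already have a Hive UDAF implementation on the classpath, it can be registered through a HiveContext. A hedged sketch; com.example.hive.BelowThresholdUDAF is a hypothetical class that would have to implement Hive's UDAF interface, and df_table is a hypothetical Hive table:

val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc)

// Register the Hive UDAF under a SQL-callable name
hiveContext.sql(
  "CREATE TEMPORARY FUNCTION belowThreshold AS 'com.example.hive.BelowThresholdUDAF'")

hiveContext.sql("SELECT `group`, belowThreshold(power) FROM df_table GROUP BY `group`")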

Unsupported / internal methods

Internally, Spark uses a number of classes, including ImperativeAggregates and DeclarativeAggregates.

These are intended for internal usage and may change without notice, so it is probably not something you want to use in your production code, but just for completeness, a BelowThreshold with DeclarativeAggregate could be implemented like this (tested with Spark 2.2-SNAPSHOT):

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

case class BelowThreshold(child: Expression, threshold: Expression)
    extends DeclarativeAggregate {
  override def children: Seq[Expression] = Seq(child, threshold)

  override def nullable: Boolean = false
  override def dataType: DataType = BooleanType

  private lazy val belowThreshold = AttributeReference(
    "belowThreshold", BooleanType, nullable = false
  )()

  // Used to derive schema
  override lazy val aggBufferAttributes = belowThreshold :: Nil

  override lazy val initialValues = Seq(
    Literal(false)
  )

  override lazy val updateExpressions = Seq(Or(
    belowThreshold,
    If(IsNull(child), Literal(false), LessThan(child, threshold))
  ))

  override lazy val mergeExpressions = Seq(
    Or(belowThreshold.left, belowThreshold.right)
  )

  override lazy val evaluateExpression = belowThreshold
  override def defaultResult: Option[Literal] = Option(Literal(false))
} 

It should be further wrapped with an equivalent of withAggregateFunction.
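withAggregateFunction itself is private to Spark, but a minimal stand-in is easy to sketch (assuming the imports from the block above are in scope; this mirrors what the private helper does, it is not a public API):

import org.apache.spark.sql.Column

// Wrap an internal AggregateFunction into an AggregateExpression
// and expose it as a regular Column that can be passed to agg(...)
def withAggregateFunction(
    func: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction,
    isDistinct: Boolean = false): Column =
  new Column(func.toAggregateExpression(isDistinct))

// Hypothetical usage:
df.groupBy($"group")
  .agg(withAggregateFunction(BelowThreshold($"power".expr, Literal(-40))).alias("belowThreshold"))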

Defining and using a UDF in Spark (3.0+) with Java:

private static UDF1<Integer, Boolean> belowThreshold = (power) -> power < -40;

Register the UDF:

SparkSession.builder()
    .appName(appName)
    .master(master)
    .getOrCreate()
    .udf()
    .register("belowThreshold", belowThreshold, DataTypes.BooleanType);

Use the UDF via Spark SQL:

spark.sql("SELECT belowThreshold('50')");