Spark User-Defined Aggregate Functions (UDAF)

UDAF

UDAF stands for User Defined Aggregate Function. Before looking at UDAFs, it helps to recall what an aggregate function is:

The following SQL queries the total sales of each shop from a table of sales records:

SELECT SUM(sale) AS total_sale, shop FROM sale_records GROUP BY shop  

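Here, the function SUM is an aggregate function.

An aggregate function takes the rows (Row) of a group as input, performs a computation over them, and outputs a single value.

In this query, the individual sale amounts (sale) of each shop after grouping are the input, and the shop's total sales (total_sale) is the output.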
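To make the input and output concrete, here is a minimal sketch in plain Scala of what the query above computes. It does not use Spark; the SaleRecord case class, the sample rows, and the SumByShop object are assumptions made only for illustration.

```scala
// Hypothetical row type standing in for the sale_records table.
final case class SaleRecord(shop: String, sale: Double)

object SumByShop {
  // Hypothetical sample data.
  val records: Seq[SaleRecord] = Seq(
    SaleRecord("shop_a", 10.0),
    SaleRecord("shop_b", 5.0),
    SaleRecord("shop_a", 2.5)
  )

  // GROUP BY shop, then SUM(sale) within each group:
  // many rows per shop go in, one value per shop comes out.
  def totalSaleByShop(rows: Seq[SaleRecord]): Map[String, Double] =
    rows.groupBy(_.shop).map { case (shop, group) =>
      shop -> group.map(_.sale).sum
    }
}
```

Calling SumByShop.totalSaleByShop(SumByShop.records) yields one total per shop, which is the role SUM plays in the SQL.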

A UDAF is an aggregate function defined by the user, as opposed to the built-in ones.

UDAF example

Join the values of a field within each group into one string, separated by commas
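As a rough picture of the intended result, the following plain-Scala sketch performs the same grouping and joining on an in-memory collection. The sample (id, category) pairs and the JoinByGroup object are hypothetical and serve only to illustrate the goal; this is not how Spark executes the aggregation.

```scala
object JoinByGroup {
  // Hypothetical (id, category) rows.
  val rows: Seq[(Int, String)] = Seq((1, "scala"), (1, "spark"), (2, "sql"))

  // Group by id, then join the categories of each group with the delimiter.
  def categoriesById(input: Seq[(Int, String)], delimiter: String): Map[Int, String] =
    input.groupBy(_._1).map { case (id, group) =>
      id -> group.map(_._2).mkString(delimiter)
    }
}
```

On a local collection the values keep their original order within each group. In a distributed aggregation the order in which values are joined generally depends on how the data is partitioned.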

Define the UDAF class

Reference: http://spark.apache.org/docs/1.6.0/api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

class MakeStringUDAF extends UserDefinedAggregateFunction {

  // Input columns: the value to join and the delimiter to join with.
  override def inputSchema: StructType = StructType(StructField("field_name", StringType) :: StructField("delimiter", StringType) :: Nil)

  // Buffer columns: the values joined so far and the delimiter in use.
  override def bufferSchema: StructType = StructType(StructField("fields", StringType) :: StructField("delimiter", StringType) :: Nil)

  // Type of the final result.
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  // Start every group with an empty string and no delimiter.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
    buffer(1) = ""
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val fieldValue = input.getAs[String](0)
    // Fall back to a comma when the delimiter argument is null.
    val delimiter = Option(input.getAs[String](1)).getOrElse(",")

    // Skip null and empty values.
    if (fieldValue != null && fieldValue.nonEmpty) {
      val bufferFieldValue = buffer.getAs[String](0)
      buffer(0) =
        if (bufferFieldValue != null && bufferFieldValue.nonEmpty) bufferFieldValue + delimiter + fieldValue
        else fieldValue
    }
    // Remember the delimiter so that merge can reuse it.
    buffer(1) = delimiter
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val leftFieldValue = buffer1.getAs[String](0)
    val rightFieldValue = buffer2.getAs[String](0)

    // Nothing to do when buffer2 holds no values.
    if (rightFieldValue.nonEmpty) {
      // A non-empty buffer2 has been updated at least once, so its delimiter is set.
      val delimiter = buffer2.getAs[String](1)
      buffer1(0) =
        if (leftFieldValue.nonEmpty) leftFieldValue + delimiter + rightFieldValue
        else rightFieldValue
      buffer1(1) = delimiter
    }
  }

  // The joined string is the final result.
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[String](0)
  }

}

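initialize sets the aggregation buffer to its initial state

update folds one input row into the buffer

merge combines two buffers (for example, partial results from different partitions) and stores the result in the first one

evaluate computes the final value from the buffer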
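The way these four methods cooperate can be sketched without Spark. In the simulation below a plain String stands in for the aggregation buffer and nested Seqs stand in for partitions; both are simplifying assumptions, and the UdafLifecycle object is a hypothetical name, not part of Spark. Each partition is reduced with initialize and update, the partial buffers are combined with merge, and evaluate produces the final value.

```scala
object UdafLifecycle {
  // A plain String plays the role of the aggregation buffer (assumption).
  def initialize(): String = ""

  // Fold one input value into the buffer, skipping null and empty values.
  def update(buffer: String, value: String, delimiter: String): String =
    if (value == null || value.isEmpty) buffer
    else if (buffer.isEmpty) value
    else buffer + delimiter + value

  // Combine two partial buffers.
  def merge(left: String, right: String, delimiter: String): String =
    if (right.isEmpty) left
    else if (left.isEmpty) right
    else left + delimiter + right

  // Produce the final value from the buffer.
  def evaluate(buffer: String): String = buffer

  // Simulate one group whose rows are spread over several partitions.
  def aggregate(partitions: Seq[Seq[String]], delimiter: String): String = {
    val partials = partitions.map { rows =>
      rows.foldLeft(initialize()) { (buffer, value) => update(buffer, value, delimiter) }
    }
    val merged = partials.foldLeft(initialize()) { (left, right) => merge(left, right, delimiter) }
    evaluate(merged)
  }
}
```

Because merge may receive a buffer that has seen no rows at all, it has to treat the initial state as a neutral element, which is why both merge and update check for the empty string.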

Register the UDAF

sqlContext.udf.register("make_string", new MakeStringUDAF)

Usage

SELECT id, make_string(category, ',') AS categories FROM book GROUP BY id