
SparkSQL (4): User-Defined Functions


 

User-defined aggregate functions

Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max(), and min(). Beyond these, users can define their own aggregate functions.
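As a quick illustration (not part of the original post), these built-in aggregates can be called directly through the DataFrame API. The sketch below assumes a DataFrame df with name and age columns, like the user data used later in this article:

import org.apache.spark.sql.functions.{avg, countDistinct, max, min}

// Assumes `df` is a DataFrame with "name" and "age" columns
df.select(avg("age"), max("age"), min("age"), countDistinct("name")).show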

 

Weakly typed user-defined aggregate functions

A weakly typed user-defined aggregate function is implemented by extending UserDefinedAggregateFunction. The example below is a custom aggregate function that computes the average age.

First, create the custom aggregate function class. It extends the abstract class UserDefinedAggregateFunction and implements its abstract methods:

package sparksql.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

class MyAvg extends UserDefinedAggregateFunction {

  // Schema of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = {
    new StructType()
      .add("sum", LongType).add("count", LongType)
  }

  // Type of the returned value
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic: the same input always produces the same output
  override def deterministic: Boolean = true

  /**
   * Initializes the buffer. The contract is that applying the merge function
   * on two initial buffers should just return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // How the buffer is updated when a new input row arrives
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1L
  }

  // Merges two aggregation buffers and stores the updated values back into `buffer1`
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Calculates the final result of this UserDefinedAggregateFunction from the aggregation buffer
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}

Next, create an instance of this class, register it, and use it in a SQL query to test it:

package sparksql.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}

object Demo1 {

  def main(args: Array[String]): Unit = {

    // Create a SparkConf and set the application name
    val conf = new SparkConf().setAppName("sparksql").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    val userDF: DataFrame = spark.read.json("E:/idea/spark3/in/user.json")
    userDF.createOrReplaceTempView("user")

    // Create an instance of the aggregate function
    val myavg = new MyAvg()
    // Register the aggregate function
    spark.udf.register("udfavg", myavg)
    // Use the aggregate function
    spark.sql("select udfavg(age) from user").show

  }

}

The output is:

+-----------+
|udfavg(age)|
+-----------+
|       21.0|
+-----------+
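As a sanity check (not in the original post), the built-in avg over the same temp view should return the same value:

spark.sql("select avg(age) from user").show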

 

Strongly typed user-defined aggregate functions

A strongly typed custom aggregate function is implemented by extending Aggregator; the example again computes the average age.

First, create the custom aggregate class, extend the abstract class Aggregator, and implement its abstract methods:

package sparksql.udf

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class UserBean(name: String, age: Long)
case class Buffer(var sum: Long, var count: Long)

class MyAvg2 extends Aggregator[UserBean, Buffer, Double] {

  // Initial value of the buffer.
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  override def zero: Buffer = Buffer(0, 0)

  /*
   * How the buffer is updated when an input value arrives.
   * Combine two values to produce a new value. For performance, the function may modify `b`
   * and return it instead of constructing a new object for b.
   */
  override def reduce(b: Buffer, a: UserBean): Buffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1L
    b
  }

  // Merge two buffers
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // Compute the final result
  override def finish(reduction: Buffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  override def bufferEncoder: Encoder[Buffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

The type parameters of Aggregator have the following meanings:

IN: the input type for the aggregation

BUF: the type of the intermediate value of the reduction (the buffer)

OUT: the type of the final output result

/**
 * @tparam IN The input type for the aggregation.
 * @tparam BUF The type of the intermediate value of the reduction.
 * @tparam OUT The type of the final output result.
 * @since 1.6.0
 */
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {

Next, create an instance of the custom class, convert it to a TypedColumn, and use it in a select call:

package sparksql.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, TypedColumn}

object Demo2 {

  def main(args: Array[String]): Unit = {

    // Create a SparkConf and set the application name
    val conf = new SparkConf().setAppName("sparksql").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    val userDF: DataFrame = spark.read.json("E:/idea/spark3/in/user.json")
    import spark.implicits._
    val userDS: Dataset[UserBean] = userDF.as[UserBean]

    // Create an instance of the MyAvg2 class
    val myavg2 = new MyAvg2()

    // Convert the instance into a TypedColumn
    val udfavg: TypedColumn[UserBean, Double] = myavg2.toColumn.name("myavg")

    // Use it in a select
    userDS.select(udfavg).show

    spark.stop()
  }
}
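A side note not covered in the original post: starting with Spark 3.0, UserDefinedAggregateFunction is deprecated, and a typed Aggregator can also be exposed to SQL by wrapping it with org.apache.spark.sql.functions.udaf. The following is only a sketch, assuming Spark 3.0+ and a hypothetical MyAvg3 whose input type is the raw column value (Long) rather than UserBean:

package sparksql.udf

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// Hypothetical variant of MyAvg2 that takes the column value (Long) directly as input,
// so it can be registered for SQL use via functions.udaf (Spark 3.0+ only)
object MyAvg3 extends Aggregator[Long, Buffer, Double] {
  override def zero: Buffer = Buffer(0L, 0L)
  override def reduce(b: Buffer, a: Long): Buffer = { b.sum += a; b.count += 1L; b }
  override def merge(b1: Buffer, b2: Buffer): Buffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  override def finish(reduction: Buffer): Double = reduction.sum.toDouble / reduction.count
  override def bufferEncoder: Encoder[Buffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Registration and use, assuming the same "user" temp view created in Demo1:
// spark.udf.register("myavg3", functions.udaf(MyAvg3, Encoders.scalaLong))
// spark.sql("select myavg3(age) from user").show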

 

Pitfall:

When a DataFrame is created with spark.read.json, numeric fields are automatically inferred as bigint. If you later try to fit such a field into an Int, the following exception is thrown:

Exception in thread "main" org.apache.spark.sql.AnalysisException: Cannot up cast `age` from bigint to int as it may truncate

For example, if you declare the case class UserBean like this:

case class UserBean(name: String, age: Int)

and then create the DataFrame and Dataset like this:

    val userDF: DataFrame = spark.read.json("E:/idea/spark3/in/user.json")
    import spark.implicits._
    val userDS: Dataset[UserBean] = userDF.as[UserBean]

the exception above will be thrown.

Therefore, declare the field as Long so it can hold the inferred bigint value.
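Alternatively (an addition not in the original post), if you are certain the values fit into an Int, you can explicitly cast the column down before converting to a Dataset, so the up-cast check no longer applies. A minimal sketch, assuming the userDF and import spark.implicits._ from the demo above, and a hypothetical UserBeanInt case class:

import org.apache.spark.sql.functions.col

// Hypothetical case class with an Int field (define it at the top level, not inside main)
case class UserBeanInt(name: String, age: Int)

// Cast the inferred bigint column to int before calling .as[...]
val userDSInt: Dataset[UserBeanInt] = userDF
  .withColumn("age", col("age").cast("int"))
  .as[UserBeanInt]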


Original post: https://www.cnblogs.com/chxyshaodiao/p/12390940.html
