User-Defined Aggregate Functions
Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max(), and min(). Beyond these, users can define their own aggregate functions.
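As a quick illustration of the built-ins (a minimal sketch; the DataFrame userDF and its age column are assumed, matching the examples later in this post):

import org.apache.spark.sql.functions.{avg, count, countDistinct, max, min}

// Built-in aggregations through the DataFrame API; userDF is assumed to exist
userDF.select(
  count("age"),
  countDistinct("age"),
  avg("age"),
  max("age"),
  min("age")
).show()

// The same in SQL, assuming the "user" temp view registered later in this post:
// spark.sql("select count(age), avg(age), max(age), min(age) from user").show()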
Weakly Typed User-Defined Aggregate Functions
A weakly typed user-defined aggregate function is implemented by extending UserDefinedAggregateFunction. Below is a custom aggregate function that computes the average age.
First, create the custom aggregate function class. It extends the abstract class UserDefinedAggregateFunction and implements its abstract methods:
package sparksql.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

class MyAvg extends UserDefinedAggregateFunction {

  // Schema of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = {
    new StructType()
      .add("sum", LongType)
      .add("count", LongType)
  }

  // Data type of the output
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic, i.e. always returns the same output for the same input
  override def deterministic: Boolean = true

  /**
   * Initializes the buffer. The contract is that applying the merge function on two initial
   * buffers should just return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // How the buffer is updated when a new input row arrives
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1L
  }

  // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Calculates the final result of this UserDefinedAggregateFunction based on the given aggregation buffer
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
Then create an instance of this class, register it, and use the registered function in a SQL statement to test it:
package sparksql.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}

object Demo1 {
  def main(args: Array[String]): Unit = {
    // Create a SparkConf and set the application name
    val conf = new SparkConf().setAppName("sparksql").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    val userDF: DataFrame = spark.read.json("E:/idea/spark3/in/user.json")
    userDF.createOrReplaceTempView("user")

    // Create an instance of the aggregate function
    val myavg = new MyAvg()
    // Register the aggregate function
    spark.udf.register("udfavg", myavg)
    // Use the aggregate function
    spark.sql("select udfavg(age) from user").show
  }
}
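For reference, spark.read.json expects line-delimited JSON (one object per line). The actual file contents are not given in the original post; a purely illustrative user.json consistent with the output below could look like this:

{"name": "zhangsan", "age": 20}
{"name": "lisi", "age": 21}
{"name": "wangwu", "age": 22}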
The output is as follows:
+----------+
|myavg(age)|
+----------+
| 21.0|
+----------+
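The registered function can also be invoked through the DataFrame API. A minimal sketch, assuming "udfavg" has been registered as above and userDF is still in scope (callUDF is deprecated in newer Spark versions in favor of call_udf, but works the same way):

import org.apache.spark.sql.functions.{callUDF, col}

// Equivalent invocation of the registered UDAF without writing SQL
userDF.select(callUDF("udfavg", col("age"))).show()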
Strongly Typed User-Defined Aggregate Functions
A strongly typed custom aggregate function is implemented by extending Aggregator; this example again computes the average age.
First, create the custom aggregate class, extend the abstract class Aggregator, and implement its abstract methods:
package sparksql.udf

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class UserBean(name: String, age: Long)
case class Buffer(var sum: Long, var count: Long)

class MyAvg2 extends Aggregator[UserBean, Buffer, Double] {

  // The initial value of the buffer.
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b.
  override def zero: Buffer = Buffer(0, 0)

  // How to update the buffer when a new value arrives.
  // Combine two values to produce a new value. For performance, the function may modify `b`
  // and return it instead of constructing a new object for b.
  override def reduce(b: Buffer, a: UserBean): Buffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1L
    b
  }

  // Merge two buffers
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // The final result
  override def finish(reduction: Buffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  override def bufferEncoder: Encoder[Buffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
The generic type parameters of Aggregator have the following meanings:
IN: the input data type
BUF: the buffer data type
OUT: the output data type
/**
 * @tparam IN  The input type for the aggregation.
 * @tparam BUF The type of the intermediate value of the reduction.
 * @tparam OUT The type of the final output result.
 * @since 1.6.0
 */
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
Then create an instance of the custom class, convert it to a TypedColumn, and use it in the select method:
package sparksql.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, TypedColumn}

object Demo2 {
  def main(args: Array[String]): Unit = {
    // Create a SparkConf and set the application name
    val conf = new SparkConf().setAppName("sparksql").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    val userDF: DataFrame = spark.read.json("E:/idea/spark3/in/user.json")
    import spark.implicits._
    val userDS: Dataset[UserBean] = userDF.as[UserBean]

    // Create an instance of MyAvg2
    val myavg2 = new MyAvg2()
    // Convert the instance to a TypedColumn
    val udfavg: TypedColumn[UserBean, Double] = myavg2.toColumn.name("myavg")
    // Use it
    userDS.select(udfavg).show

    spark.stop()
  }
}
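As a side note, in Spark 3.x UserDefinedAggregateFunction is deprecated, and an Aggregator can also be registered for use in SQL through functions.udaf. The sketch below is only an illustration under that assumption: MyAvg3 is a hypothetical variant that takes the age column directly as a Long, and it reuses the Buffer case class, the spark session, and the "user" view from the snippets above.

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// Hypothetical Aggregator whose input is the raw column value (Long), not the whole row
class MyAvg3 extends Aggregator[Long, Buffer, Double] {
  override def zero: Buffer = Buffer(0, 0)
  override def reduce(b: Buffer, a: Long): Buffer = { b.sum += a; b.count += 1; b }
  override def merge(b1: Buffer, b2: Buffer): Buffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  override def finish(reduction: Buffer): Double = reduction.sum.toDouble / reduction.count
  override def bufferEncoder: Encoder[Buffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register it so it can be used from SQL, just like the weakly typed version
spark.udf.register("udfavg3", functions.udaf(new MyAvg3(), Encoders.scalaLong))
spark.sql("select udfavg3(age) from user").show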
Pitfall:
When a DataFrame is created with spark.read.json, numeric fields are automatically inferred as bigint. If you then try to load such a field into an Int, the following exception is thrown:
Exception in thread "main" org.apache.spark.sql.AnalysisException: Cannot up cast `age` from bigint to int as it may truncate
For example, if you declare the case class UserBean like this:
case class UserBean(name: String, age: Int)
and then create the DataFrame and Dataset like this:
val userDF: DataFrame = spark.read.json("E:/idea/spark3/in/user.json")
import spark.implicits._
val userDS: Dataset[UserBean] = userDF.as[UserBean]
the exception above is thrown.
Therefore, declare such fields as Long (bigint) to receive the values.
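If you do want an Int field, two possible workarounds (a sketch, assuming import spark.implicits._ is in scope as in the snippets above) are to cast the column before the .as conversion, or to supply an explicit schema when reading:

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Option 1: keep age as Int in the case class, but cast the column before converting to a Dataset
// (assumes UserBean is declared as UserBean(name: String, age: Int))
val userDS1 = userDF
  .withColumn("age", col("age").cast(IntegerType))
  .as[UserBean]

// Option 2: supply an explicit schema when reading, so age is read as Int in the first place
val schema = new StructType()
  .add("name", StringType)
  .add("age", IntegerType)
val userDS2 = spark.read.schema(schema).json("E:/idea/spark3/in/user.json").as[UserBean]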
Original post (Chinese): https://www.cnblogs.com/chxyshaodiao/p/12390940.html