Generic UDAF in Spark 3.0 using Aggregator



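Spark 3.0 deprecated UserDefinedAggregateFunction, and I am trying to rewrite my UDAFs using Aggregator. Basic Aggregator usage is simple, but I am struggling to write a more generic version of the function.

I'll try to explain my problem with this example, an implementation of collect_set. It is not my actual case, but it makes the problem easier to explain: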

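import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

class CollectSetDemoAgg(name: String) extends Aggregator[Row, Set[Int], Set[Int]] {
  // Empty buffer for each aggregation group
  override def zero: Set[Int] = Set.empty[Int]
  // Add the value of the named column from one input row
  override def reduce(b: Set[Int], a: Row): Set[Int] = b + a.getInt(a.fieldIndex(name))
  // Combine partial buffers from different partitions
  override def merge(b1: Set[Int], b2: Set[Int]): Set[Int] = b1 ++ b2
  override def finish(reduction: Set[Int]): Set[Int] = reduction
  override def bufferEncoder: Encoder[Set[Int]] = Encoders.kryo[Set[Int]]
  override def outputEncoder: Encoder[Set[Int]] = ExpressionEncoder[Set[Int]]()
}

// using it:
df.agg(new CollectSetDemoAgg("rank").toColumn as "result").show()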

I prefer .toColumn over .udf.register, but that is beside the point.

Problem: I cannot make a generic version of this aggregator; it only handles integers.

I tried:
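For readers less familiar with the Aggregator contract, here is a minimal Spark-free sketch of how the four methods fit together for a collect_set-style aggregation. The object name `CollectSetContract`, the `aggregate` helper, and the sample partitions are hypothetical and only model what Spark does per partition; the result is sorted purely to make the output deterministic, which the real aggregator does not do.

```scala
// Spark-free model of the zero / reduce / merge / finish contract.
object CollectSetContract {
  def zero: Set[Int] = Set.empty[Int]
  def reduce(buffer: Set[Int], value: Int): Set[Int] = buffer + value
  def merge(left: Set[Int], right: Set[Int]): Set[Int] = left ++ right
  // Sorting is an assumption made here only for a stable, printable result.
  def finish(buffer: Set[Int]): Seq[Int] = buffer.toSeq.sorted

  // Hypothetical helper: fold each partition with reduce,
  // then combine the partial buffers with merge.
  def aggregate(partitions: Seq[Seq[Int]]): Seq[Int] = {
    val partials = partitions.map(_.foldLeft(zero)(reduce))
    finish(partials.foldLeft(zero)(merge))
  }

  def main(args: Array[String]): Unit = {
    val result = aggregate(Seq(Seq(1, 2, 2), Seq(2, 3)))
    println(result.mkString(", ")) // prints "1, 2, 3"
  }
}
```

The buffer type (`Set[Int]` here) is what `bufferEncoder` has to serialize, and the return type of `finish` is what `outputEncoder` has to describe, which is where the generic version runs into trouble.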

class CollectSetDemo(name: String) extends Aggregator[Row, Set[Any], Set[Any]] 

It fails with this error:

No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
java.lang.UnsupportedOperationException: No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerFor$1(ScalaReflection.scala:567)

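I could not go with CollectSetDemo[T] because I was not able to write a proper outputEncoder for it. Also, when using a UDAF I can only work with Spark data types, columns, and so on.

I have not found a good way to solve this, but I was able to work around it to some extent. The code is partly borrowed from RowEncoder: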

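import scala.reflect.ClassTag
import scala.reflect.runtime.universe.typeTag

import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._

class CollectSetDemoAgg(name: String, fieldType: DataType) extends Aggregator[Row, Set[Any], Any] {
  override def zero: Set[Any] = Set.empty[Any]
  override def reduce(b: Set[Any], a: Row): Set[Any] = b + a.get(a.fieldIndex(name))
  override def merge(b1: Set[Any], b2: Set[Any]): Set[Any] = b1 ++ b2
  override def finish(reduction: Set[Any]): Any = reduction.toSeq
  override def bufferEncoder: Encoder[Set[Any]] = Encoders.kryo[Set[Any]]

  // Build the output encoder from the declared result type.
  // Note: ScalaReflection and this ExpressionEncoder constructor are internal Catalyst APIs.
  override def outputEncoder: Encoder[Any] = {
    val mirror = ScalaReflection.mirror
    val tt = fieldType match {
      case ArrayType(LongType, _) => typeTag[Seq[Long]]
      case ArrayType(IntegerType, _) => typeTag[Seq[Int]]
      case ArrayType(StringType, _) => typeTag[Seq[String]]
      // .. etc etc
      case _ => throw new RuntimeException(s"Could not create encoder for ${name} column (${fieldType})")
    }
    val tpe = tt.in(mirror).tpe
    val cls = mirror.runtimeClass(tpe)
    val serializer = ScalaReflection.serializerForType(tpe)
    val deserializer = ScalaReflection.deserializerForType(tpe)
    new ExpressionEncoder[Any](serializer, deserializer, ClassTag[Any](cls))
  }
}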

One thing I had to add is a result data type parameter on the aggregator. The usage then changes to:

df.agg(new CollectSetDemoAgg("rank", new ArrayType(IntegerType, true)).toColumn as "result").show()

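I really don't like how it turned out, but it works. I also welcome any suggestions on how to improve it.

A modification of @Ramunas's answer using generics: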

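import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, typeTag}

import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.ScalaReflection.{deserializerForType, mirror, serializerForType}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

class CollectSetDemoAgg[T: TypeTag](name: String) extends Aggregator[Row, Set[T], Seq[T]] {
  override def zero: Set[T] = Set.empty[T]
  override def reduce(b: Set[T], a: Row): Set[T] = b + a.getAs[T](a.fieldIndex(name))
  override def merge(b1: Set[T], b2: Set[T]): Set[T] = b1 ++ b2
  override def finish(reduction: Set[T]): Seq[T] = reduction.toSeq
  override def bufferEncoder: Encoder[Set[T]] = Encoders.kryo[Set[T]]

  // The TypeTag context bound lets the encoder be derived for Seq[T] at the call site's T.
  override def outputEncoder: Encoder[Seq[T]] = {
    val tt = typeTag[Seq[T]]
    val tpe = tt.in(mirror).tpe
    val cls = mirror.runtimeClass(tpe)
    val serializer = serializerForType(tpe)
    val deserializer = deserializerForType(tpe)
    new ExpressionEncoder[Seq[T]](serializer, deserializer, ClassTag[Seq[T]](cls))
  }
}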
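Another route that may be worth trying is a sketch built on `functions.udaf` (available since Spark 3.0), which wraps an Aggregator so it can be applied to columns. Here the aggregator's input type is the column value `T` rather than a `Row`, so no column name has to be passed in, and the output encoder is derived with `ExpressionEncoder[Seq[T]]()` from the `TypeTag`. The names `CollectSetUdaf` and `CollectSetUdafDemo` and the sample data are hypothetical, and the sketch assumes a non-null column of a type Spark can already encode; I have not checked it against every Spark 3.x release.

```scala
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Hypothetical generic aggregator: input is the column value itself, not a Row.
class CollectSetUdaf[T: TypeTag] extends Aggregator[T, Set[T], Seq[T]] {
  override def zero: Set[T] = Set.empty[T]
  override def reduce(buffer: Set[T], value: T): Set[T] = buffer + value
  override def merge(b1: Set[T], b2: Set[T]): Set[T] = b1 ++ b2
  override def finish(reduction: Set[T]): Seq[T] = reduction.toSeq
  // Kryo keeps the buffer opaque, so T needs no Spark encoder here.
  override def bufferEncoder: Encoder[Set[T]] = Encoders.kryo[Set[T]]
  // Derived from the TypeTag; assumes Seq[T] is a type Spark can encode.
  override def outputEncoder: Encoder[Seq[T]] = ExpressionEncoder[Seq[T]]()
}

object CollectSetUdafDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("collect-set-udaf")
      .getOrCreate()
    import spark.implicits._

    // Sample data is an assumption for illustration only.
    val df = Seq(1, 2, 2, 3).toDF("rank")

    // Wrap the typed Aggregator so it can be applied to a column.
    val collectInts = udaf(new CollectSetUdaf[Int])
    df.agg(collectInts(col("rank")) as "result").show()

    spark.stop()
  }
}
```

The trade-off is that the element type still has to be chosen in code (`CollectSetUdaf[Int]`), so this does not help when the column type is only known at runtime; in that case a match on the column's DataType, as in the earlier answer, is still needed.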
