Return multiple arrays from a user-defined aggregate function (UDAF) in Apache Spark SQL

I am trying to create a user-defined aggregate function (UDAF) in Java using Apache Spark SQL that returns multiple arrays on completion. I have searched online and cannot find any examples or suggestions on how to do this.

I am able to return a single array, but I cannot figure out how to get the data into the correct format in the evaluate() method to return multiple arrays.

The UDAF itself works, since I can print out the arrays inside the evaluate() method; I just cannot figure out how to return those arrays to the calling code (shown below for reference).

UserDefinedAggregateFunction customUDAF = new CustomUDAF();
DataFrame resultingDataFrame = dataFrame.groupBy().agg(
    customUDAF.apply(dataFrame.col("long_col"), dataFrame.col("double_col")).as("processed_data"));

I have included the whole custom UDAF class below, but the key methods are dataType() and evaluate(), which are shown first.

Any help or suggestions would be greatly appreciated. Many thanks.

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class CustomUDAF extends UserDefinedAggregateFunction {
    @Override
    public DataType dataType() {
        // TODO: Is this the correct way to return 2 arrays?
        return new StructType().add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
            .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }
    @Override
    public Object evaluate(Row buffer) {
        // Data conversion
        List<Long> longList = new ArrayList<Long>(buffer.getList(0));
        List<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        // Processing of data (omitted)
        // TODO: How to get data into format needed to return 2 arrays?
        return dataList;
    }
    @Override
    public StructType inputSchema() {
        return new StructType().add("long", DataTypes.LongType).add("data", DataTypes.DoubleType);
    }
    @Override
    public StructType bufferSchema() {
        return new StructType().add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
            .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, new ArrayList<Long>());
        buffer.update(1, new ArrayList<Double>());
    }
    @Override
    public void update(MutableAggregationBuffer buffer, Row row) {
        ArrayList<Long> longList = new ArrayList<Long>(buffer.getList(0));
        longList.add(row.getLong(0));
        ArrayList<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        dataList.add(row.getDouble(1));
        buffer.update(0, longList);
        buffer.update(1, dataList);
    }
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        ArrayList<Long> longList = new ArrayList<Long>(buffer1.getList(0));
        longList.addAll(buffer2.getList(0));
        ArrayList<Double> dataList = new ArrayList<Double>(buffer1.getList(1));
        dataList.addAll(buffer2.getList(1));
        buffer1.update(0, longList);
        buffer1.update(1, dataList);
    }
    @Override
    public boolean deterministic() {
        return true;
    }
}

Update: based on zero323's answer, I was able to return two arrays using:

return new Tuple2<>(longArray, dataArray);

Getting the data out of that was a bit tricky, but it involved deconstructing the DataFrame into Java Lists and then building it back up into a DataFrame.
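Roughly, the deconstruction looked something like the sketch below (assuming the "processed_data" alias from the calling code above and the field names declared in dataType(); the exact calls may vary by Spark version):

    // Option 1: collect the single result row and read the struct's fields.
    Row result = resultingDataFrame.first();
    Row struct = result.getStruct(0);            // the two-field struct from the UDAF
    List<Long> longs = struct.getList(0);        // "longArray" field
    List<Double> doubles = struct.getList(1);    // "dataArray" field

    // Option 2: project the struct fields out as top-level array columns instead,
    // avoiding the collect-and-rebuild round trip.
    DataFrame arrays = resultingDataFrame.select(
        resultingDataFrame.col("processed_data.longArray"),
        resultingDataFrame.col("processed_data.dataArray"));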

As far as I can tell, returning a Tuple should be just enough. In Scala:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}
object DummyUDAF extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", StringType)
  def bufferSchema = new StructType()
    .add("buff", ArrayType(LongType))
    .add("buff2", ArrayType(DoubleType))
  def dataType = new StructType()
    .add("xs", ArrayType(LongType))
    .add("ys", ArrayType(DoubleType))
  def deterministic = true 
  def initialize(buffer: MutableAggregationBuffer) = {}
  def update(buffer: MutableAggregationBuffer, input: Row) = {}
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {}
  def evaluate(buffer: Row) = (Array(1L, 2L, 3L), Array(1.0, 2.0, 3.0))
}
val df = sc.parallelize(Seq(("a", 1), ("b", 2))).toDF("k", "v")
df.select(DummyUDAF($"k")).show(1, false)
// +---------------------------------------------------+
// |(DummyUDAF$(k),mode=Complete,isDistinct=false)     |
// +---------------------------------------------------+
// |[WrappedArray(1, 2, 3),WrappedArray(1.0, 2.0, 3.0)]|
// +---------------------------------------------------+
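
The same approach should work from Java, where Scala tuples are clumsier to build: since dataType() is a StructType, evaluate() can plausibly return a Row whose fields match the declared struct instead. A minimal, untested sketch for the CustomUDAF above (assumes org.apache.spark.sql.RowFactory is imported):

    // Sketch: return a Row matching dataType() rather than a Tuple2.
    @Override
    public Object evaluate(Row buffer) {
        List<Long> longList = new ArrayList<Long>(buffer.getList(0));
        List<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        // Processing of data (omitted)
        return RowFactory.create(longList, dataList); // fields: longArray, dataArray
    }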

Latest update