How to use countDistinct in Scala with Spark



Following Databricks' blog, I have tried to use the countDistinct function, which should be available in Spark 1.5. However, I got the following exception:

Exception in thread "main" org.apache.spark.sql.AnalysisException: undefined function countDistinct;

I found that on the Spark developers' mailing list they suggest using the count and distinct functions to get the same result that countDistinct should produce:

count(distinct <columnName>)
// instead of
countDistinct(<columnName>)
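
For reference, in Spark 1.5 the DataFrame-level countDistinct function itself resolves fine; it is only the same name inside a SQL expression string that fails, apparently because it is not registered as a SQL function. A minimal illustration of the three spellings (df and column B are placeholders):

import org.apache.spark.sql.functions.{countDistinct, expr}

df.agg(countDistinct("B"))         // works: Scala API function
df.agg(expr("count(distinct B)"))  // works: standard SQL syntax
df.agg(expr("countDistinct(B)"))   // AnalysisException: undefined function countDistinct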

Because I build the aggregation expressions dynamically from a list of aggregation function names, I would prefer not to have any special cases that require different treatment.

So, is it possible to unify this by:

  • registering a new UDAF which would be an alias for count(distinct columnName)
  • manually registering the CountDistinct function already implemented in Spark, which probably comes from the following import:

    import org.apache.spark.sql.catalyst.expressions.{CountDistinctFunction, CountDistinct}

  • or doing it in some other way?

EDIT: Example (with some local references and unnecessary code removed):

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Column, SQLContext, DataFrame}
import org.apache.spark.sql.functions._
import scala.collection.mutable.ListBuffer

// DType and TypeRecognizer are local helpers (omitted here) that classify columns.
class Flattener(sc: SparkContext) {
  val sqlContext = new SQLContext(sc)

  def flatTable(data: DataFrame, groupField: String): DataFrame = {
    // Build one aggregation expression per (column, type) pair.
    val flatteningExpressions = data.columns.zip(TypeRecognizer.getTypes(data)).
      flatMap(x => getFlatteningExpressions(x._1, x._2)).toList
    data.groupBy(groupField).agg(
      expr(s"count($groupField) as groupSize"),
      flatteningExpressions: _*
    )
  }

  // Render each aggregation function as a SQL string: "f(field) as field_f".
  private def getFlatteningExpressions(fieldName: String, fieldType: DType): List[Column] = {
    val aggFuncs = getAggregationFunctions(fieldType)
    aggFuncs.map(f => expr(s"$f($fieldName) as ${fieldName}_$f"))
  }

  private def getAggregationFunctions(fieldType: DType): List[String] = {
    val aggFuncs = new ListBuffer[String]()
    if (fieldType == DType.NUMERIC) {
      aggFuncs += ("avg", "min", "max")
    }
    if (fieldType == DType.CATEGORY) {
      aggFuncs += "countDistinct" // this is the name that fails to resolve in expr(...)
    }
    aggFuncs.toList
  }
}
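
To make the failure concrete, this is what the dynamically built expressions look like for hypothetical columns "price" (NUMERIC) and "country" (CATEGORY):

import org.apache.spark.sql.functions.expr

// NUMERIC -> avg/min/max, all of which resolve during analysis:
val numericExprs = List("avg", "min", "max").map(f => expr(s"$f(price) as price_$f"))
// CATEGORY -> countDistinct; expr(...) parses it, but the query fails once analyzed:
val categoryExpr = expr("countDistinct(country) as country_countDistinct")
// => AnalysisException: undefined function countDistinct;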

countDistinct can be used in two different forms:

df.groupBy("A").agg(expr("count(distinct B)")

df.groupBy("A").agg(countDistinct("B"))

However, neither of these works when you want to use them on the same column together with a custom UDAF (implemented as a UserDefinedAggregateFunction in Spark 1.5):

// Assume that we have already implemented and registered StdDev UDAF 
df.groupBy("A").agg(countDistinct("B"), expr("StdDev(B)"))
// Will cause
Exception in thread "main" org.apache.spark.sql.AnalysisException: StdDev is implemented based on the new Aggregate Function interface and it cannot be used with functions implemented based on the old Aggregate Function interface.;

Given these limitations, the most reasonable option seems to be implementing countDistinct as a UDAF, which should allow all functions to be treated in the same way and countDistinct to be used alongside other UDAFs.

An example implementation can look like this:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class CountDistinct extends UserDefinedAggregateFunction {
  // One string input column.
  override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)

  // The buffer holds the distinct values seen so far, stored as an array.
  override def bufferSchema: StructType = StructType(
    StructField("items", ArrayType(StringType, true)) :: Nil
  )

  override def dataType: DataType = IntegerType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq[String]()
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = (buffer.getSeq[String](0).toSet + input.getString(0)).toSeq
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = (buffer1.getSeq[String](0).toSet ++ buffer2.getSeq[String](0).toSet).toSeq
  }

  // The distinct count is the size of the collected set.
  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[String](0).length
  }
}
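
With the class above, registering it should make the string form resolve, so every aggregation can go through the same expr-based code path. A sketch (that the registry accepts the name countDistinct is an assumption; pick another name such as distinctCount if it clashes):

// Register the UDAF so that expr("countDistinct(...)") resolves:
sqlContext.udf.register("countDistinct", new CountDistinct)

// The custom UDAF and other UDAFs now share one code path:
df.groupBy("A").agg(expr("countDistinct(B) as B_countDistinct"), expr("StdDev(B) as B_stdDev"))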

Not sure if I really understood your problem, but here is an example of the countDistinct aggregate function:

// Run e.g. in spark-shell, where sqlContext.implicits._ is already in scope for toDF.
val values = Array((1, 2), (1, 3), (2, 2), (1, 2))
val myDf = sc.parallelize(values).toDF("id", "foo")
import org.apache.spark.sql.functions.countDistinct
myDf.groupBy('id).agg(countDistinct('foo) as 'distinctFoo).show()
/**
+---+-------------------+
| id|COUNT(DISTINCT foo)|
+---+-------------------+
|  1|                  2|
|  2|                  1|
+---+-------------------+
*/
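
For completeness, the equivalent SQL-string form mentioned earlier gives the same counts:

import org.apache.spark.sql.functions.expr

myDf.groupBy('id).agg(expr("count(distinct foo)") as 'distinctFoo).show()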
