Apache Spark SQL UDAF在窗口上显示了带有重复输入的奇数行为



我发现在Apache Spark SQL(版本2.2.0)中,当用户定义的聚合函数(UDAF)在窗口规范上使用的用户定义输入,UDAF确实(似乎)未正确调用evaluate方法。

我已经能够在Java和Scala,本地和集群中重现这种行为。下面的代码显示了一个示例,其中行被标记为False,如果它们在上一行的1秒内。

class ExampleUDAF(val timeLimit: Long) extends UserDefinedAggregateFunction {
  def deterministic: Boolean = true
  def inputSchema: StructType = StructType(Array(StructField("unix_time", LongType)))
  def dataType: DataType = BooleanType
  def bufferSchema = StructType(Array(
    StructField("previousKeepTime", LongType),
    StructField("keepRow", BooleanType)
  ))
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = 0L
    buffer(1) = false
  }
  def update(buffer: MutableAggregationBuffer, input: Row) = {    
    if (buffer(0) == 0L) {
      buffer(0) = input.getLong(0)
      buffer(1) = true
    } else {
      val timeDiff = input.getLong(0) - buffer.getLong(0)
      if (timeDiff < timeLimit) {
        buffer(1) = false
      } else {
        buffer(0) = input.getLong(0)
        buffer(1) = true
      }
    }
  }
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {} // Not implemented
  def evaluate(buffer: Row): Boolean = buffer.getBoolean(1)
 }
val timeLimit = 1000 // 1 second
val udaf = new ExampleUDAF(timeLimit)
val window = Window
  .orderBy(column("unix_time"))
  .partitionBy(column("category"))
val df = spark.createDataFrame(Arrays.asList(
    Row(1510000001000L, "a", true), 
    Row(1510000001000L, "a", false), 
    Row(1510000001000L, "a", false),
    Row(1510000001000L, "a", false),
    Row(1510000700000L, "a", true),
    Row(1510000700000L, "a", false)
  ), new StructType().add("unix_time", LongType).add("category", StringType).add("expected_result", BooleanType))
df.withColumn("actual_result", udaf(column("unix_time")).over(window)).show

以下是运行上述代码的输出。由于没有先前的数据,因此预计第一行将具有TRUE的actual_result值。当修改unix_time输入以在每个记录之间具有1毫秒时,UDAF按预期工作。

在UDAF方法中添加打印语句表明evaluate仅在最后一次调用一次,并且该缓冲区在update方法中正确更新为True,但这不是UDAF完成后返回的内容。

+-------------+--------+---------------+-------------+
|    unix_time|category|expected_result|actual_result|
+-------------+--------+---------------+-------------+
|1510000001000|       a|           true|        false|  // Should true as first element
|1510000001000|       a|          false|        false|
|1510000001000|       a|          false|        false|
|1510000001000|       a|          false|        false|
|1510000700000|       a|           true|        false|  // Should be true as more than 1000 milliseconds between self and previous
|1510000700000|       a|          false|        false|
+-------------+--------+---------------+-------------+

我在窗口规格上使用时正确理解Spark的UDAF行为吗?如果没有,任何人都可以在这方面提供任何见解。如果我对Windows对UDAF行为的理解是正确的,那么这可能是Spark中的错误吗?谢谢。

UDAF的一个问题是,它没有指定要使用rowsBetween()运行窗口的行。如果没有rowsBetween()规范,则对于每一行,窗口函数将 ast ALL 请参见下面和之后的)行之前和之后的当前一行(包括当前一个)(在给定类别中)。因此,所有行的actual_result基本上将仅考虑您示例DataFrame中的最后两个行,而unix_time=1510000700000实际上将有效地返回所有行的false

使用window声明:

Window.partitionBy(col("category")).orderBy(col("unix_time")).rowsBetween(-1L, 0L)

您始终仅在上一行和当前行上寻找。先前的排首先。这会创建正确的输出。但是,由于使用相同unix_time的行排序不是唯一的,因此无法预测哪个行在具有相同unix_time的行之间具有值true

结果看起来像这样:

+-------------+--------+---------------+-------------+
|    unix_time|category|expected_result|actual_result|
+-------------+--------+---------------+-------------+
|1510000001000|       a|          false|         true|
|1510000001000|       a|          false|        false|
|1510000001000|       a|          false|        false|
|1510000001000|       a|           true|        false|
|1510000700000|       a|           true|         true|
|1510000700000|       a|          false|        false|
+-------------+--------+---------------+-------------+

update

进一步研究后,似乎提供orderBy列时,将所有元素取在当前行 当前行之前。并非像我之前说过的所有分区元素。另外,如果订单列包含每个重复的行的重复值窗口将包含所有重复的值。您可以通过这样做清楚地看到它:

val wA = Window.partitionBy(col("category")).orderBy(col("unix_time"))
val wB = Window.partitionBy(col("category"))
val wC = Window.partitionBy(col("category")).orderBy(col("unix_time")).rowsBetween(-1L, 0L)
df.withColumn("countRows", count(col("unix_time")).over(wA)).show()
df.withColumn("countRows", count(col("unix_time")).over(wB)).show()
df.withColumn("countRows", count(col("unix_time")).over(wC)).show()

将计算每个窗口中的元素数量。

  • 窗口wA将在每1510000001000行中有4个元素,每1510000700000。
  • 对于wB,当没有orderBy时,每个分区的窗口中都包含所有行,因此所有窗口都将具有6个元素。
  • 最后一个wC指定行的选择,因此不会在哪个窗口中选择哪个行歧义。在所有后续行的窗口中,第一行只有1个元素,2个元素。产生正确的结果。

我今天也学到了一些新知识:)

最新更新