How to aggregate values into a collection after groupBy?



I have a DataFrame with the following schema:

[visitorId: string, trackingIds: array<string>, emailIds: array<string>]

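I'm looking for a way to group (or roll up?) this DataFrame by visitorId, appending the trackingIds and emailIds arrays of each group together. For example, if my initial df looks like: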

+---------+------------+--------+
|visitorId| trackingIds|emailIds|
+---------+------------+--------+
|     a158|      [666b]|    [12]|
|     7g21|      [c0b5]|    [45]|
|     7g21|      [c0b4]|    [87]|
|     a158|[666b, 777c]|      []|
+---------+------------+--------+

I would like my output df to look like this:

+---------+------------------+--------+
|visitorId|       trackingIds|emailIds|
+---------+------------------+--------+
|     a158|[666b, 666b, 777c]|[12, '']|
|     7g21|      [c0b5, c0b4]|[45, 87]|
+---------+------------------+--------+

I tried using the groupBy and agg operators but didn't have much luck.
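To pin down the intended result independently of Spark, here is a plain-Scala sketch over local collections. `VisitorRow` and `GroupSemantics` are hypothetical names used only for illustration, not Spark APIs. The sketch keeps values in input order. As in the desired output above, it turns an empty `emailIds` array into a single `""` placeholder.

```scala
// Plain Scala (no Spark) model of the requested aggregation; illustration only.
// VisitorRow is a hypothetical stand-in for one row of the DataFrame.
final case class VisitorRow(visitorId: String, trackingIds: Seq[String], emailIds: Seq[String])

object GroupSemantics {
  // Group rows by visitorId and concatenate both array columns in input order.
  // An empty emailIds array contributes a single "" placeholder.
  def aggregate(rows: Seq[VisitorRow]): Seq[VisitorRow] = {
    val keyOrder = rows.map(_.visitorId).distinct // first-seen order of keys
    val grouped  = rows.groupBy(_.visitorId)      // groupBy keeps order within a group
    keyOrder.map { id =>
      val group = grouped(id)
      VisitorRow(
        id,
        group.flatMap(_.trackingIds),
        group.flatMap(r => if (r.emailIds.isEmpty) Seq("") else r.emailIds))
    }
  }
}
```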

Spark>=2.4

You can replace the flatten UDF with the built-in flatten function:

import org.apache.spark.sql.functions.flatten

and leave the rest as is.
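Put together for Spark 2.4+, the pipeline might look roughly like the sketch below. It assumes a DataFrame with the question's column names. `BuiltinFlattenExample` is a hypothetical wrapper object. The order of the elements produced by `collect_list` is not guaranteed after a shuffle.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{array, collect_list, flatten, lit, size, when}

// Hypothetical wrapper object; assumes columns visitorId, trackingIds, emailIds.
object BuiltinFlattenExample {
  def aggregate(df: DataFrame): DataFrame = {
    // Replace empty emailIds arrays with [""] so each input row leaves a placeholder
    val withPlaceholders = df.withColumn(
      "emailIds",
      when(size(df("emailIds")) === 0, array(lit(""))).otherwise(df("emailIds")))

    // collect_list builds array<array<string>> per group;
    // the built-in flatten (Spark 2.4+) concatenates the inner arrays
    withPlaceholders
      .groupBy("visitorId")
      .agg(
        flatten(collect_list("trackingIds")).alias("trackingIds"),
        flatten(collect_list("emailIds")).alias("emailIds"))
  }
}
```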

Spark>=2.0,<2.4

It is possible, but rather expensive. Using the data you've provided:

import spark.implicits._ // for toDF and $; already in scope in spark-shell

case class Record(
    visitorId: String, trackingIds: Array[String], emailIds: Array[String])
val df = Seq(
  Record("a158", Array("666b"), Array("12")),
  Record("7g21", Array("c0b5"), Array("45")),
  Record("7g21", Array("c0b4"), Array("87")),
  Record("a158", Array("666b", "777c"), Array.empty[String])).toDF

and a helper function:

import org.apache.spark.sql.functions.udf
val flatten = udf((xs: Seq[Seq[String]]) => xs.flatten)

we can fill the blanks with placeholders:

import org.apache.spark.sql.functions.{array, lit, size, when}
// Replace an empty emailIds array with [""] so the row keeps a placeholder
val dfWithPlaceholders = df.withColumn(
  "emailIds",
  when(size($"emailIds") === 0, array(lit(""))).otherwise($"emailIds"))

and then collect_list followed by flatten:

import org.apache.spark.sql.functions.collect_list
val emailIds = flatten(collect_list($"emailIds")).alias("emailIds")
val trackingIds = flatten(collect_list($"trackingIds")).alias("trackingIds")
dfWithPlaceholders
  .groupBy($"visitorId")
  .agg(trackingIds, emailIds)
// +---------+------------------+--------+
// |visitorId|       trackingIds|emailIds|
// +---------+------------------+--------+
// |     a158|[666b, 666b, 777c]|  [12, ]|
// |     7g21|      [c0b5, c0b4]|[45, 87]|
// +---------+------------------+--------+

With a statically typed Dataset:

dfWithPlaceholders.as[Record]
  .groupByKey(_.visitorId)
  .mapGroups { case (key, vs) =>
    // Split each group into (trackingIds, emailIds) pairs, then concatenate each side
    vs.map(v => (v.trackingIds, v.emailIds)).toArray.unzip match {
      case (trackingIds, emailIds) =>
        Record(key, trackingIds.flatten, emailIds.flatten)
    }
  }
// +---------+------------------+--------+
// |visitorId|       trackingIds|emailIds|
// +---------+------------------+--------+
// |     a158|[666b, 666b, 777c]|  [12, ]|
// |     7g21|      [c0b5, c0b4]|[45, 87]|
// +---------+------------------+--------+

Spark 1.x

You can convert to an RDD and group:

import org.apache.spark.sql.Row
dfWithPlaceholders.rdd
  .map {
     case Row(id: String, 
       trcks: Seq[String @ unchecked],
       emails: Seq[String @ unchecked]) => (id, (trcks, emails))
  }
  .groupByKey
  .map {case (key, vs) => vs.toArray.unzip match {
    case (trackingIds, emailIds) => 
      Record(key, trackingIds.flatten, emailIds.flatten)
  }}
  .toDF
// +---------+------------------+--------+
// |visitorId|       trackingIds|emailIds|
// +---------+------------------+--------+
// |     7g21|      [c0b5, c0b4]|[45, 87]|
// |     a158|[666b, 666b, 777c]|  [12, ]|
// +---------+------------------+--------+

@zero323's answer is pretty complete, but Spark gives us even more flexibility. How about the following solution?

import org.apache.spark.sql.functions._
inventory
  .select($"*", explode($"trackingIds") as "tracking_id")
  .select($"*", explode($"emailIds") as "email_id")
  .groupBy("visitorId")
  .agg(
    collect_list("tracking_id") as "trackingIds",
    collect_list("email_id") as "emailIds")

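However, explode produces no rows for an empty array, so a row whose trackingIds or emailIds is empty disappears entirely, taking its values from the other column with it. Also, because the two explodes are applied one after the other, a row with several elements in both arrays yields their cross product, which duplicates values. So there is still some room for improvement :)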
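One possible improvement is sketched below, assuming an extra shuffle and a join are acceptable. It explodes and collects each array column separately, then joins the two results on `visitorId`. This avoids both dropping rows that have an empty array and the cross product. `PerColumnExplode` is a hypothetical helper name. By contrast, `explode_outer` (Spark 2.2+) would only address the dropped rows, not the cross product.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, collect_list, explode}

// Hypothetical helper; assumes columns visitorId, trackingIds, emailIds.
object PerColumnExplode {
  // Explode one array column on its own and collect it back per visitorId,
  // so empty arrays in the other column cannot drop rows or multiply values.
  private def collectColumn(df: DataFrame, arrayCol: String): DataFrame =
    df.select(col("visitorId"), explode(col(arrayCol)).as("value"))
      .groupBy("visitorId")
      .agg(collect_list("value").as(arrayCol))

  // Full outer join keeps visitors that only have values in one of the columns;
  // for those, the missing side comes back as null rather than an empty array.
  def aggregate(df: DataFrame): DataFrame =
    collectColumn(df, "trackingIds")
      .join(collectColumn(df, "emailIds"), Seq("visitorId"), "full_outer")
}
```

Unlike the placeholder-based answer above, this keeps no `""` entry for empty arrays, so a158 would get emailIds [12].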

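You can use a user-defined aggregate function (UDAF). Note that UserDefinedAggregateFunction, used below, has been deprecated since Spark 3.0 in favor of Aggregator.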

1) Create a custom UDAF as a Scala class named CustomAggregation.

package com.example // placeholder; "package" is a reserved word and cannot appear in a package name

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class CustomAggregation() extends UserDefinedAggregateFunction {
  // Input data type schema
  def inputSchema: StructType =
    StructType(Array(StructField("col5", ArrayType(StringType))))

  // Intermediate (buffer) schema
  def bufferSchema: StructType =
    StructType(Array(StructField("col5_collapsed", ArrayType(StringType))))

  // Returned data type
  def dataType: DataType = ArrayType(StringType)

  // The same input always produces the same output
  def deterministic: Boolean = true

  // Initializes an empty buffer for a new aggregation
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[String]
  }

  // Iterate over each entry of a group; null arrays are skipped
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0))
      buffer(0) = buffer.getSeq[String](0) ++ input.getSeq[String](0)
  }

  // Merge two partial aggregates
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0)
  }

  // Called after all the entries are exhausted.
  // Note: distinct removes duplicates, so a158 gets [666b, 777c], not [666b, 666b, 777c].
  def evaluate(buffer: Row): Any =
    buffer.getSeq[String](0).distinct
}

2) Then use the UDAF in your code as follows:

// define the UDAF
val customAggregation = new CustomAggregation()
df
  .groupBy($"visitorId")
  .agg(
    customAggregation($"trackingIds").as("trackingIds"),
    customAggregation($"emailIds").as("emailIds"))
  .show()
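On Spark 3.0+, a roughly equivalent aggregation can be written as a typed Aggregator and registered with `functions.udaf`. The sketch below rests on a few assumptions. `ConcatArrays` and `ConcatArraysUsage` are hypothetical names. It keeps duplicates instead of applying `distinct`. It builds its encoders with `ExpressionEncoder`, which lives in Spark's internal catalyst package rather than the public API.

```scala
import org.apache.spark.sql.{DataFrame, Encoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Typed aggregator that concatenates the array<string> values of a group.
// Unlike CustomAggregation it keeps duplicates (no distinct).
object ConcatArrays extends Aggregator[Seq[String], Seq[String], Seq[String]] {
  // Empty buffer for each group / partition
  def zero: Seq[String] = Seq.empty[String]
  // Append one input array; null arrays are skipped
  def reduce(buf: Seq[String], in: Seq[String]): Seq[String] =
    if (in == null) buf else buf ++ in
  // Combine partial results computed on different partitions
  def merge(b1: Seq[String], b2: Seq[String]): Seq[String] = b1 ++ b2
  def finish(buf: Seq[String]): Seq[String] = buf
  // ExpressionEncoder is catalyst-internal; assumption: Spark 3.x
  def bufferEncoder: Encoder[Seq[String]] = ExpressionEncoder[Seq[String]]()
  def outputEncoder: Encoder[Seq[String]] = ExpressionEncoder[Seq[String]]()
}

object ConcatArraysUsage {
  // Wrap the Aggregator as an untyped function usable inside DataFrame.agg
  private val concatArrays = udaf(ConcatArrays)

  // Assumes columns visitorId, trackingIds, emailIds
  def aggregate(df: DataFrame): DataFrame =
    df.groupBy("visitorId")
      .agg(
        concatArrays(col("trackingIds")).as("trackingIds"),
        concatArrays(col("emailIds")).as("emailIds"))
}
```

Empty arrays simply contribute nothing here. To get the `""` placeholder from the question, the empty emailIds arrays would still need to be replaced before aggregating.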
