Spark在groupBy/aggregate中合并/组合数组



下面的Spark代码正确地演示了我想要做的事情,并使用一个小的演示数据集生成了正确的输出。

当我在大量生产数据上运行相同类型的代码时,我遇到了运行时问题。Spark作业在我的集群上运行了大约12个小时,然后失败了。

看一下下面的代码,把每一行都展开,只是为了合并回去,似乎效率很低。在给定的测试数据集中,在array_value_1中有三个值,在array_value_2中有三个值的第四行将爆炸为3*3或9个爆炸行。

那么,在一个更大的数据集中,一行有5个这样的数组列,每列有10个值,将爆炸到10^5个爆炸行?

查看提供的Spark函数,没有开箱即用的函数可以做我想要的。我可以提供一个用户定义的函数。这在速度上有什么缺点吗?

val sparkSession = SparkSession.builder.
  master("local")
  .appName("merge list test")
  .getOrCreate()
val schema = StructType(
  StructField("category", IntegerType) ::
    StructField("array_value_1", ArrayType(StringType)) ::
    StructField("array_value_2", ArrayType(StringType)) ::
    Nil)
val rows = List(
  Row(1, List("a", "b"), List("u", "v")),
  Row(1, List("b", "c"), List("v", "w")),
  Row(2, List("c", "d"), List("w")),
  Row(2, List("c", "d", "e"), List("x", "y", "z"))
)
val df = sparkSession.createDataFrame(rows.asJava, schema)
val dfExploded = df.
  withColumn("scalar_1", explode(col("array_value_1"))).
  withColumn("scalar_2", explode(col("array_value_2")))
// This will output 19. 2*2 + 2*2 + 2*1 + 3*3 = 19
logger.info(s"dfExploded.count()=${dfExploded.count()}")
val dfOutput = dfExploded.groupBy("category").agg(
  collect_set("scalar_1").alias("combined_values_2"),
  collect_set("scalar_2").alias("combined_values_2"))
dfOutput.show()

对于explode来说可能是低效的,但从根本上说,您尝试实现的操作是昂贵的。实际上,它只是另一个groupByKey,你在这里没有什么可以做得更好。因为你使用Spark> 2.0,你可以直接collect_list和平坦:

import org.apache.spark.sql.functions.{collect_list, udf}
val flatten_distinct = udf(
  (xs: Seq[Seq[String]]) => xs.flatten.distinct)
df
  .groupBy("category")
  .agg(
    flatten_distinct(collect_list("array_value_1")), 
    flatten_distinct(collect_list("array_value_2"))
  )

在Spark>= 2.4中,你可以用内置函数的组合来代替udf:

import org.apache.spark.sql.functions.{array_distinct, flatten}
val flatten_distinct = (array_distinct _) compose (flatten _)

也可以使用自定义Aggregator,但我怀疑这些都不会产生巨大的差异。

如果集合比较大,你希望有大量的重复,你可以尝试使用aggregateByKey与可变集合:

import scala.collection.mutable.{Set => MSet}
val rdd = df
  .select($"category", struct($"array_value_1", $"array_value_2"))
  .as[(Int, (Seq[String], Seq[String]))]
  .rdd
val agg = rdd
  .aggregateByKey((MSet[String](), MSet[String]()))( 
    {case ((accX, accY), (xs, ys)) => (accX ++= xs, accY ++ ys)},
    {case ((accX1, accY1), (accX2, accY2)) => (accX1 ++= accX2, accY1 ++ accY2)}
  )
  .mapValues { case (xs, ys) => (xs.toArray, ys.toArray) }
  .toDF

相关内容

  • 没有找到相关文章

最新更新