如何避免在 Spark 中广播大型查找表



你能帮我避免广播大型查找表吗?我有一个测量表:

Measurement     Value
x1              5.1
x2              8.9
x1              9.1
x3              4.4
x2              2.1
...

和配对列表:

P1              P2
x1              x2
x2              x3
...

任务是获取每对元素的所有值,并将它们放入魔术函数中。这就是我通过广播带有测量值的大表来解决它的方式。

case class Measurement(measurement: String, value: Double)
case class Candidate(c1: String, c2: String)
val measurements = Seq(Measurement("x1", 5.1), Measurement("x2", 8.9), 
  Measurement("x1", 9.1), Measurement("x3", 4.4))
val candidates = Seq(Candidate("x1", "x2"), Candidate("x2", "x3"))
// create data frames
val dfm = sqc.createDataFrame(measurements)
val dfc = sqc.createDataFrame(candidates)
// broadcast lookup table
val lookup = sc.broadcast(dfm.rdd.map(r => (r(0), r(1))).collect())
// udf: run magic test with every candidate
val magic: ((String, String) => Double) = (c1: String, c2: String) => {
  val lt = lookup.value
  val c1v = lt.filter(_._1 == c1).map(_._2).map(_.asInstanceOf[Double])
  val c2v = lt.filter(_._1 == c2).map(_._2).map(_.asInstanceOf[Double])
  new Foo().magic(c1v, c2v)
}
val sq1 = udf(magic)
val dfks = dfc.withColumn("magic", sq1(col("c1"), col("c2")))

正如您可以猜到的那样,我对解决方案不太满意。对于每对,我都会过滤查找表两次,这既不快也不优雅。我正在使用Spark 1.6.1。

另一种方法是使用 RDD 并加入。不确定在性能方面有什么更好。

case class Measurement(measurement: String, value: Double)
case class Candidate(c1: String, c2: String)
val measurements = Seq(Measurement("x1", 5.1), Measurement("x2", 8.9), 
Measurement("x1", 9.1), Measurement("x3", 4.4))
val candidates = Seq(Candidate("x1", "x2"), Candidate("x2", "x3"))
val rdm = sc.parallelize(measurements).map(r => (r.measurement, r.value)).groupByKey().cache()
val rdc = sc.parallelize(candidates).map(r => (r.c1, r.c2)).cache()
val firstColJoin = rdc.join(rdm).values
val secondColJoin = firstColJoin.join(rdm).values
secondColJoin.map { case (c1v, c2v) => new Foo().magic(c1v, c2v) }

感谢您的所有评论。我阅读了评论,做了一些研究并研究了zero323帖子。

我目前的解决方案是使用两个joins和一个UserDefinedAggregateFunction

object GroupValues extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", DoubleType)
  def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
  def dataType = ArrayType(DoubleType)
  def deterministic = true
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0, ArrayBuffer.empty[Double])
  }
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, buffer.getSeq[Double](0) :+ input.getDouble(0))
  }
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0))
  }
  def evaluate(buffer: Row) = buffer.getSeq[Double](0)
}
// join data for candidate one
val j1 = dfc.join(dfm, dfc("c1") === dfm("measurement"))
// aggregate all c1 values to an array
val j1v = j1.groupBy(col("c1"), col("c2")).agg(GroupValues(col("value"))
  .alias("c1-values"))
// join data for candidate two
val j2 = j1v.join(dfm, j1v("c2") === dfm("measurement"))
// aggregate all c2 values to an array
val j2v = j2.groupBy(col("c1"), col("c2"), col("c1-values"))
  .agg(GroupValues(col("value")).alias("c2-values"))

下一步是使用 collect_list 而不是 UserDefinedAggregateFunction .

相关内容

  • 没有找到相关文章

最新更新