Can you help me avoid broadcasting a large lookup table? I have a table of measurements:
Measurement Value
x1 5.1
x2 8.9
x1 9.1
x3 4.4
x2 2.1
...
and a list of pairs:
P1 P2
x1 x2
x2 x3
...
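The task is to get all values for each element of a pair and pass them to a magic function. This is how I solved it, by broadcasting the large table of measurements: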
import org.apache.spark.sql.functions.{col, udf}
case class Measurement(measurement: String, value: Double)
case class Candidate(c1: String, c2: String)
val measurements = Seq(Measurement("x1", 5.1), Measurement("x2", 8.9),
  Measurement("x1", 9.1), Measurement("x3", 4.4))
val candidates = Seq(Candidate("x1", "x2"), Candidate("x2", "x3"))
// create data frames (sqc is assumed to be the SQLContext, sc the SparkContext)
val dfm = sqc.createDataFrame(measurements)
val dfc = sqc.createDataFrame(candidates)
// broadcast lookup table as typed (measurement, value) pairs
val lookup = sc.broadcast(dfm.rdd.map(r => (r.getString(0), r.getDouble(1))).collect())
// udf: run magic test with every candidate
val magic: (String, String) => Double = (c1: String, c2: String) => {
  val lt = lookup.value
  // scan the whole lookup table once per pair element
  val c1v = lt.filter(_._1 == c1).map(_._2)
  val c2v = lt.filter(_._1 == c2).map(_._2)
  new Foo().magic(c1v, c2v)
}
val sq1 = udf(magic)
val dfks = dfc.withColumn("magic", sq1(col("c1"), col("c2")))
As you can guess, I am not too happy with this solution. For each pair I filter the lookup table twice, which is neither fast nor elegant. I am using Spark 1.6.1.
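One way to avoid the double filter, while still broadcasting, would be to group the lookup table by measurement name once and look each name up in a map. The following is only a minimal sketch in plain Scala collections; `GroupedLookup`, `build` and `valuesFor` are hypothetical names that do not appear in the code above.

```scala
object GroupedLookup {
  // Group (name, value) rows once, so that each name maps to all of its values.
  def build(rows: Seq[(String, Double)]): Map[String, Seq[Double]] =
    rows.groupBy(_._1).map { case (name, pairs) => name -> pairs.map(_._2) }

  // Map lookup instead of a scan over the whole table; unknown names
  // yield an empty sequence.
  def valuesFor(lookup: Map[String, Seq[Double]], name: String): Seq[Double] =
    lookup.getOrElse(name, Seq.empty)
}
```

Assuming the collected rows are typed as `(String, Double)` pairs, the broadcast value could then be `sc.broadcast(GroupedLookup.build(rows))`, and the UDF would call `valuesFor` twice instead of filtering twice. This does not remove the broadcast itself; it only reduces the work done per pair.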
An alternative is to use RDDs and joins. I am not sure which approach performs better.
case class Measurement(measurement: String, value: Double)
case class Candidate(c1: String, c2: String)
val measurements = Seq(Measurement("x1", 5.1), Measurement("x2", 8.9),
  Measurement("x1", 9.1), Measurement("x3", 4.4))
val candidates = Seq(Candidate("x1", "x2"), Candidate("x2", "x3"))
// group all values per measurement name: (name, values)
val rdm = sc.parallelize(measurements).map(r => (r.measurement, r.value)).groupByKey().cache()
// candidate pairs keyed by c1: (c1, c2)
val rdc = sc.parallelize(candidates).map(r => (r.c1, r.c2)).cache()
// join on c1 and drop the key, leaving (c2, c1 values)
val firstColJoin = rdc.join(rdm).values
// join on c2 and drop the key, leaving (c1 values, c2 values)
val secondColJoin = firstColJoin.join(rdm).values
secondColJoin.map { case (c1v, c2v) => new Foo().magic(c1v, c2v) }
Thank you for all the comments. I read them, did some research, and studied zero323's post.
My current solution uses two joins and a UserDefinedAggregateFunction:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{ArrayType, DoubleType, StructType}
import scala.collection.mutable.ArrayBuffer
// collects all non-null double values of a group into one array
object GroupValues extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", DoubleType)
  def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
  def dataType = ArrayType(DoubleType)
  def deterministic = true
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, ArrayBuffer.empty[Double])
  }
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0))
      buffer.update(0, buffer.getSeq[Double](0) :+ input.getDouble(0))
  }
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0))
  }
  def evaluate(buffer: Row) = buffer.getSeq[Double](0)
}
// join data for candidate one
val j1 = dfc.join(dfm, dfc("c1") === dfm("measurement"))
// aggregate all c1 values to an array
val j1v = j1.groupBy(col("c1"), col("c2"))
  .agg(GroupValues(col("value")).alias("c1-values"))
// join data for candidate two
val j2 = j1v.join(dfm, j1v("c2") === dfm("measurement"))
// aggregate all c2 values to an array
val j2v = j2.groupBy(col("c1"), col("c2"), col("c1-values"))
  .agg(GroupValues(col("value")).alias("c2-values"))
The next step is to use collect_list instead of the UserDefinedAggregateFunction.
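A rough sketch of what that could look like follows. It is not the code I am running: it swaps in a slightly different plan that aggregates the measurements once per name and then joins the result twice, which also avoids grouping by an array column. `CollectListSketch`, `pairValues` and the intermediate column names `m1`, `m2`, `c1_values` and `c2_values` are made up for illustration. As far as I know, on Spark 1.6 `collect_list` is backed by the Hive UDAF, so the data frames would have to come from a `HiveContext`.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, collect_list}

object CollectListSketch {
  // dfc is expected to have columns (c1, c2), dfm columns (measurement, value).
  def pairValues(dfc: DataFrame, dfm: DataFrame): DataFrame = {
    // One row per measurement name, holding all of its values as an array.
    val grouped = dfm.groupBy(col("measurement"))
      .agg(collect_list(col("value")).alias("values"))
    // Two renamed views of the grouped table, one per side of the pair.
    val g1 = grouped.select(col("measurement").alias("m1"), col("values").alias("c1_values"))
    val g2 = grouped.select(col("measurement").alias("m2"), col("values").alias("c2_values"))
    dfc.join(g1, col("c1") === col("m1"))
      .join(g2, col("c2") === col("m2"))
      .select("c1", "c2", "c1_values", "c2_values")
  }
}
```

The two array columns can then be passed to a UDF that calls the magic function. The order of the values inside each array is not guaranteed.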