我有以下输入阵列
val bins = (("bin1",1.0,2.0),("bin2",3.0,4.0),("bin3",5.0,6.0))
基本上,字符串" bin1"是指滤波数据框的参考列中的值 - 基于边界条件从另一个列创建了一个新列,以在数组中剩余两个双打
中剩余两个双打var number_of_dataframes = bins.length
var ctempdf = spark.createDataFrame(sc.emptyRDD[Row],train_data.schema)
ctempdf = ctempdf.withColumn(colName,col(colName))
val t1 = System.nanoTime
for ( x<- 0 to binputs.length-1)
{
var tempdf = train_data.filter(col(refCol) === bins(x)._1)
//println(binputs(x)._1)
tempdf = tempdf.withColumn(colName,
when(col(colName) < bins(x)._2, bins(x)._2)
when(col(colName) > bins(x)._3, bins(x)._3)
otherwise(col(colName)))
ctempdf = ctempdf.union(tempdf)
val duration = (System.nanoTime - t1) / 1e9d
println(duration)
}
上面的代码在每一个增加的垃圾箱的价值中逐渐缓慢起作用 - 有什么方法我可以大幅度加速它 - 因为此代码再次嵌套在另一个循环中。
我使用了检查点/persist/cache,这些都没有帮助
这里不需要迭代联合。使用o.a.s.sql.functions.map
创建一个字面的map<string, struct<double, double>>
(用功能术语与延迟的string => struct<lower: dobule, upper: double>
一样)
import org.apache.spark.sql.functions._
val bins: Seq[(String, Double Double)] = Seq(
("bin1",1.0,2.0),("bin2",3.0,4.0),("bin3",5.0,6.0))
val binCol = map(bins.map {
case (key, lower, upper) => Seq(
lit(key),
struct(lit(lower) as "lower", lit(upper) as "upper"))
}.flatten: _*)
定义这样的表达式(这些是预定义的映射中的简单查找,因此binCol(col(refCol))
延迟了struct<lower: dobule, upper: double>
,其余的apply
采用lower
或upper
字段):
val lower = binCol(col(refCol))("lower")
val upper = binCol(col(refCol))("upper")
val c = col(colName)
并使用 CASE ... WHEN ...
(相当于if的火花)
val result = when(c.between(lower, upper), c)
.when(c < lower, lower)
.when(c > upper, upper)
选择并删除NULL
S:
df
.withColumn(colName, result)
// If value is still NULL it means we didn't find refCol key in binCol keys.
// To mimic .filter(col(refCol) === ...) we drop the rows
.na.drop(Seq(colName))
此解决方案假定开始时colName
中没有NULL
值,但是可以轻松调整以处理该假设不满足的情况。
如果该过程仍不清楚,我建议您逐步跟踪它:
spark.range(1).select(binCol as "map").show(false)
+------------------------------------------------------------+
|map |
+------------------------------------------------------------+
|[bin1 -> [1.0, 2.0], bin2 -> [3.0, 4.0], bin3 -> [5.0, 6.0]]|
+------------------------------------------------------------+
spark.range(1).select(binCol(lit("bin1")) as "value").show(false)
+----------+
|value |
+----------+
|[1.0, 2.0]|
+----------+
spark.range(1).select(binCol(lit("bin1"))("lower") as "value").show
+-----+
|value|
+-----+
| 1.0|
+-----+
并进一步指查询具有复杂类型的SQL SQL数据框。