Spark scala 数据帧中某些特定列的最大值



我有这样的数据帧。

+---+---+---+---+
|  M| c2| c3| d1|
+---+---+---+---+
|  1|2_1|4_3|1_2|
|  2|3_4|4_5|1_2|
+---+---+---+---+

我必须转换这个 df 应该如下所示。在这里,c_max = max(c2,c3)在使用 .ie 拆分后_,所有列(c2c3(都必须用_拆分,然后获得最大值。

在实际场景中,我有 50 列,即c2,c3....c50,需要从中获取最大值。

+---+---+---+---+------+
|  M| c2| c3| d1|c_Max |
+---+---+---+---+------+
|  1|2_1|4_3|1_2|  4   |
|  2|3_4|4_5|1_2|  5   |
+---+---+---+---+------+

以下是使用 Spark>= 2.4.0 的expr和内置数组函数的一种方法:

import org.apache.spark.sql.functions.{expr, array_max, array}
val df = Seq(
(1, "2_1", "3_4", "1_2"),
(2, "3_4", "4_5", "1_2")
).toDF("M", "c2", "c3", "d1")
// get max c for each c column 
val c_cols = df.columns.filter(_.startsWith("c")).map{ c =>
expr(s"array_max(cast(split(${c}, '_') as array<int>))")
}
df.withColumn("max_c", array_max(array(c_cols:_*))).show

输出:

+---+---+---+---+-----+
|  M| c2| c3| d1|max_c|
+---+---+---+---+-----+
|  1|2_1|3_4|1_2|    4|
|  2|3_4|4_5|1_2|    5|
+---+---+---+---+-----+

对于旧版本,请使用下一个代码:

val c_cols = df.columns.filter(_.startsWith("c")).map{ c =>
val c_ar = split(col(c), "_").cast("array<int>")
when(c_ar.getItem(0) > c_ar.getItem(1), c_ar.getItem(0)).otherwise(c_ar.getItem(1))
}
df.withColumn("max_c", greatest(c_cols:_*)).show

使用greatest函数:

val df = Seq((1, "2_1", "3_4", "1_2"),(2, "3_4", "4_5", "1_2"),
).toDF("M", "c2", "c3", "d1")
// get all `c` columns and split by `_` to get the values after the underscore
val c_cols = df.columns.filter(_.startsWith("c"))
.flatMap{
c => Seq(split(col(c), "_").getItem(0).cast("int"), 
split(col(c), "_").getItem(1).cast("int")
)
} 
// apply greatest func
val c_max = greatest(c_cols: _*)
// add new column
df.withColumn("c_Max", c_max).show()

给:

+---+---+---+---+-----+
|  M| c2| c3| d1|c_Max|
+---+---+---+---+-----+
|  1|2_1|3_4|1_2|    4|
|  2|3_4|4_5|1_2|    5|
+---+---+---+---+-----+

在 Spark>= 2.4.0 中,您可以使用array_max函数并获取一些代码,这些代码即使包含 2 个以上值的列也可以使用。这个想法是从连接所有列(concat列(开始。为此,我在我想要连接的所有列的数组上使用concat_ws,我用array(cols.map(col) :_*).然后,我拆分结果字符串以获得包含所有列的所有值的大字符串数组。我把它投射到一个整数数组,然后调用array_max

val cols = (2 to 50).map("c"+_)
val result = df
.withColumn("concat", concat_ws("_", array(cols.map(col) :_*)))
.withColumn("array_of_ints", split('concat, "_").cast(ArrayType(IntegerType)))
.withColumn("c_max", array_max('array_of_ints))
.drop("concat", "array_of_ints")

在 Spark <2.4 中,您可以像这样定义自己array_max:

val array_max = udf((s : Seq[Int]) => s.max)

前面的代码不需要修改。但请注意,UDF 可能比预定义的 Spark SQL 函数慢。

最新更新