我有这样的数据帧。
+---+---+---+---+
| M| c2| c3| d1|
+---+---+---+---+
| 1|2_1|4_3|1_2|
| 2|3_4|4_5|1_2|
+---+---+---+---+
我必须转换这个 df 应该如下所示。在这里,c_max = max(c2,c3)
在使用 .ie 拆分后_
,所有列(c2
和c3
(都必须用_
拆分,然后获得最大值。
在实际场景中,我有 50 列,即c2,c3....c50
,需要从中获取最大值。
+---+---+---+---+------+
| M| c2| c3| d1|c_Max |
+---+---+---+---+------+
| 1|2_1|4_3|1_2| 4 |
| 2|3_4|4_5|1_2| 5 |
+---+---+---+---+------+
以下是使用 Spark>= 2.4.0 的expr
和内置数组函数的一种方法:
import org.apache.spark.sql.functions.{expr, array_max, array}
val df = Seq(
(1, "2_1", "3_4", "1_2"),
(2, "3_4", "4_5", "1_2")
).toDF("M", "c2", "c3", "d1")
// get max c for each c column
val c_cols = df.columns.filter(_.startsWith("c")).map{ c =>
expr(s"array_max(cast(split(${c}, '_') as array<int>))")
}
df.withColumn("max_c", array_max(array(c_cols:_*))).show
输出:
+---+---+---+---+-----+
| M| c2| c3| d1|max_c|
+---+---+---+---+-----+
| 1|2_1|3_4|1_2| 4|
| 2|3_4|4_5|1_2| 5|
+---+---+---+---+-----+
对于旧版本,请使用下一个代码:
val c_cols = df.columns.filter(_.startsWith("c")).map{ c =>
val c_ar = split(col(c), "_").cast("array<int>")
when(c_ar.getItem(0) > c_ar.getItem(1), c_ar.getItem(0)).otherwise(c_ar.getItem(1))
}
df.withColumn("max_c", greatest(c_cols:_*)).show
使用greatest
函数:
val df = Seq((1, "2_1", "3_4", "1_2"),(2, "3_4", "4_5", "1_2"),
).toDF("M", "c2", "c3", "d1")
// get all `c` columns and split by `_` to get the values after the underscore
val c_cols = df.columns.filter(_.startsWith("c"))
.flatMap{
c => Seq(split(col(c), "_").getItem(0).cast("int"),
split(col(c), "_").getItem(1).cast("int")
)
}
// apply greatest func
val c_max = greatest(c_cols: _*)
// add new column
df.withColumn("c_Max", c_max).show()
给:
+---+---+---+---+-----+
| M| c2| c3| d1|c_Max|
+---+---+---+---+-----+
| 1|2_1|3_4|1_2| 4|
| 2|3_4|4_5|1_2| 5|
+---+---+---+---+-----+
在 Spark>= 2.4.0 中,您可以使用array_max
函数并获取一些代码,这些代码即使包含 2 个以上值的列也可以使用。这个想法是从连接所有列(concat
列(开始。为此,我在我想要连接的所有列的数组上使用concat_ws
,我用array(cols.map(col) :_*)
.然后,我拆分结果字符串以获得包含所有列的所有值的大字符串数组。我把它投射到一个整数数组,然后调用array_max
。
val cols = (2 to 50).map("c"+_)
val result = df
.withColumn("concat", concat_ws("_", array(cols.map(col) :_*)))
.withColumn("array_of_ints", split('concat, "_").cast(ArrayType(IntegerType)))
.withColumn("c_max", array_max('array_of_ints))
.drop("concat", "array_of_ints")
在 Spark <2.4 中,您可以像这样定义自己array_max:
val array_max = udf((s : Seq[Int]) => s.max)
前面的代码不需要修改。但请注意,UDF 可能比预定义的 Spark SQL 函数慢。