Merging maps in a Scala DataFrame



I have a DataFrame with columns col1, col2, col3. col1 and col2 are strings. col3 is a Map[String,String] defined as below:

|-- col3: map (nullable = true)
|    |-- key: string
|    |-- value: string (valueContainsNull = true)

I group by col1, col2 and aggregate with collect_list to get an array of maps, stored in col4.

df.groupBy($"col1", $"col2").agg(collect_list($"col3").as("col4"))
|-- col4: array (nullable = true)
|    |-- element: map (containsNull = true)
|    |    |-- key: string
|    |    |-- value: string (valueContainsNull = true)

However, I want to combine all of the maps so that col4 becomes a single map. Currently I have:

[[a->a1,b->b1],[c->c1]]

Expected output:

[a->a1,b->b1,c->c1]

Would using a UDF be the ideal approach here?

Any help is appreciated. Thanks.

You can use aggregate and map_concat:

import org.apache.spark.sql.functions.{expr, collect_list}

val df = Seq(
  (1, Map("k1" -> "v1", "k2" -> "v3")),
  (1, Map("k3" -> "v3")),
  (2, Map("k4" -> "v4")),
  (2, Map("k6" -> "v6", "k5" -> "v5"))
).toDF("id", "data")

val mergeExpr = expr("aggregate(data, map(), (acc, i) -> map_concat(acc, i))")

df.groupBy("id").agg(collect_list("data").as("data"))
  .select($"id", mergeExpr.as("merged_data"))
  .show(false)
// +---+------------------------------+
// |id |merged_data                   |
// +---+------------------------------+
// |1  |[k1 -> v1, k2 -> v3, k3 -> v3]|
// |2  |[k4 -> v4, k6 -> v6, k5 -> v5]|
// +---+------------------------------+

With map_concat we concatenate all the Map entries of the data column via the built-in aggregate higher-order function, which lets us fold the collected list of maps into a single map.
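For intuition, aggregate(array, start, (acc, x) -> ...) simply folds the array into the start value. A minimal, hedged sketch on integers (not part of the original answer):

// Hedged sketch: aggregate folds the array into the initial value using the lambda.
import org.apache.spark.sql.functions.expr

Seq((1, Seq(1, 2, 3))).toDF("id", "xs")
  .select(expr("aggregate(xs, 0, (acc, x) -> acc + x)").as("sum"))
  .show()
// the sum column contains 6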

Note: the current implementation of map_concat on Spark 2.4.5 allows identical keys to coexist. This is most likely a bug, since according to the official documentation it is not the expected behaviour. Please be aware of that.

If you want to avoid this case, you can also go for a UDF:

import org.apache.spark.sql.functions.{collect_list, udf}

val mergeMapUDF = udf((data: Seq[Map[String, String]]) => data.reduce(_ ++ _))

df.groupBy("id").agg(collect_list("data").as("data"))
  .select($"id", mergeMapUDF($"data").as("merged_data"))
  .show(false)

Update (2022-08-27)

  1. In Spark 3.3.0 the above code no longer works and throws the following exception:
AnalysisException: cannot resolve 'aggregate(`data`, map(), lambdafunction(map_concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))' due to data type mismatch: argument 3 requires map<null,null> type, however, 'lambdafunction(map_concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable())' is of map<string,string> type.;
Project [id#110, aggregate(data#119, map(), lambdafunction(map_concat(cast(lambda acc#122 as map<string,string>), lambda i#123), lambda acc#122, lambda i#123, false), lambdafunction(lambda id#124, lambda id#124, false)) AS aggregate(data, map(), lambdafunction(map_concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))#125]
+- Aggregate [id#110], [id#110, collect_list(data#111, 0, 0) AS data#119]
+- Project [_1#105 AS id#110, _2#106 AS data#111]
+- LocalRelation [_1#105, _2#106]

It seems that map() is initialized as map<null,null> while map<string,string> is expected.

To fix this, explicitly cast map() to map<string, string> with cast(map() as map<string, string>).

Here is the updated code:

val mergeExpr = expr("aggregate(data, cast(map() as map<string, string>), (acc, i) -> map_concat(acc, i))")

df.groupBy("id").agg(collect_list("data").as("data"))
  .select($"id", mergeExpr)
  .show(false)
  2. Regarding the identical-keys bug, this seems to be fixed in recent versions. If you try to add an identical key, an exception is now raised (see the sketch after the error message below):
Caused by: RuntimeException: Duplicate map key k5 was found, please check the input data. If you want to remove the duplicated keys, you can set spark.sql.mapKeyDedupPolicy to LAST_WIN so that the key inserted at last takes precedence.
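If you would rather have duplicate keys deduplicated than rejected, the message above points to spark.sql.mapKeyDedupPolicy. A minimal, hedged sketch (Spark 3.x, reusing mergeExpr from the updated code above):

// Hedged sketch: let the value inserted last win when keys collide (Spark 3.x).
spark.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")

val dupDf = Seq(
  (1, Map("k5" -> "old")),
  (1, Map("k5" -> "new"))
).toDF("id", "data")

dupDf.groupBy("id").agg(collect_list("data").as("data"))
  .select($"id", mergeExpr.as("merged_data"))
  .show(false)
// expected: k5 -> new (the last value takes precedence)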

You can achieve this without a UDF. Let's create the DataFrame:

val df = Seq(Seq(Map("a" -> "a1", "b" -> "b1"), Map("c" -> "c1", "d" -> "d1"))).toDF()
df.show(false)
df.printSchema()

Output:

+----------------------------------------+
|value                                   |
+----------------------------------------+
|[[a -> a1, b -> b1], [c -> c1, d -> d1]]|
+----------------------------------------+
root
|-- value: array (nullable = true)
|    |-- element: map (containsNull = true)
|    |    |-- key: string
|    |    |-- value: string (valueContainsNull = true)

If your array contains exactly 2 elements, just use map_concat:

df.select(map_concat('value.getItem(0), 'value.getItem(1))).show(false)

Or this (I don't know how to loop dynamically from 0 to the size of the 'value array column, which would probably be the shortest solution):

df.select(map_concat((for {i <- 0 to 1} yield 'value.getItem(i)): _*)).show(false)
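If the array length is not known in advance, one way to make this dynamic is to read the array size from the data first and build the map_concat arguments from it. A hedged sketch (it assumes every row holds the same number of maps):

// Hedged sketch: derive the number of maps from the first row, then
// concatenate value(0) .. value(n-1) dynamically.
import org.apache.spark.sql.functions.{map_concat, size}

val n = df.select(size('value)).head().getInt(0)
df.select(map_concat((0 until n).map(i => 'value.getItem(i)): _*)).show(false)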

Otherwise, if your array contains multiple maps and the size is not known in advance, you can try this:

val df2 = df.map(s => {
  val list = s.getList[Map[String, String]](0)
  var map = Map[String, String]()
  for (i <- 0 to list.size() - 1) {
    map = map ++ list.get(i)
  }
  map
})
df2.show(false)
df2.printSchema()

Output:

+------------------------------------+
|value                               |
+------------------------------------+
|[a -> a1, b -> b1, c -> c1, d -> d1]|
+------------------------------------+
root
|-- value: map (nullable = true)
|    |-- key: string
|    |-- value: string (valueContainsNull = true)
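The same per-row merge can also be written with foldLeft, which avoids the mutable variable and copes with an empty array (reduce would fail there). A hedged sketch:

// Hedged sketch: fold the collected maps into one, starting from an empty map.
val df3 = df.map(row =>
  row.getSeq[Map[String, String]](0).foldLeft(Map.empty[String, String])(_ ++ _)
)
df3.show(false)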

If the number of records is small, you can explode the maps, collect the entries as struct() and rebuild a single map with map_from_entries.

val df = Seq(Seq(Map("a" -> "a1", "b" -> "b1"), Map("c" -> "c1", "d" -> "d1"))).toDF()
df.show(false)
df.printSchema()
+----------------------------------------+
|value                                   |
+----------------------------------------+
|[{a -> a1, b -> b1}, {c -> c1, d -> d1}]|
+----------------------------------------+
root
|-- value: array (nullable = true)
|    |-- element: map (containsNull = true)
|    |    |-- key: string
|    |    |-- value: string (valueContainsNull = true)

df.createOrReplaceTempView("items")
val df2 = spark.sql("""
with t1 (select value from items),
t2 (select value, explode(value) m1 from t1 ),
t3 (select value, explode(m1) (k,v) from t2 ),
t4 (select value, struct(k,v) r1 from t3 ),
t5 (select collect_list(r1) r2 from t4 )
select map_from_entries(r2) merged_data from t5
""")
df2.show(false)
df2.printSchema
+------------------------------------+
|merged_data                         |
+------------------------------------+
|{a -> a1, b -> b1, c -> c1, d -> d1}|
+------------------------------------+
root
|-- merged_data: map (nullable = false)
|    |-- key: string
|    |-- value: string (valueContainsNull = true)

Note that when we use "value" in the group by, Spark throws org.apache.spark.sql.AnalysisException: expression t4.value cannot be used as a grouping expression because its data type array<map<string,string>> is not an orderable data type.

Let's take abiratsis's sample data as an example. Here we have to use the id column in the group by, otherwise all the map elements would be merged together.

val df = Seq(
  (1, Map("k1" -> "v1", "k2" -> "v3")),
  (1, Map("k3" -> "v3")),
  (2, Map("k4" -> "v4")),
  (2, Map("k6" -> "v6", "k5" -> "v5"))
).toDF("id", "data")
df.show(false)
df.printSchema()
+---+--------------------+
|id |data                |
+---+--------------------+
|1  |{k1 -> v1, k2 -> v3}|
|1  |{k3 -> v3}          |
|2  |{k4 -> v4}          |
|2  |{k6 -> v6, k5 -> v5}|
+---+--------------------+
root
|-- id: integer (nullable = false)
|-- data: map (nullable = true)
|    |-- key: string
|    |-- value: string (valueContainsNull = true)
df.createOrReplaceTempView("items")
val df2 = spark.sql("""
with t1 (select id, data from items),
t2 (select id, explode(data) (k,v) from t1 ),
t3 (select id, struct(k,v) r1 from t2 ),
t4 (select id, collect_list(r1) r2 from t3 group by id )
select id, map_from_entries(r2) merged_data from t4
""")
df2.show(false)
df2.printSchema
+---+------------------------------+
|id |merged_data                   |
+---+------------------------------+
|1  |{k1 -> v1, k2 -> v3, k3 -> v3}|
|2  |{k4 -> v4, k6 -> v6, k5 -> v5}|
+---+------------------------------+
root
|-- id: integer (nullable = false)
|-- merged_data: map (nullable = false)
|    |-- key: string
|    |-- value: string (valueContainsNull = true)
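For reference, the same explode / struct / map_from_entries pipeline can also be written with the DataFrame API instead of SQL. A hedged sketch on the same data:

// Hedged sketch: explode each map into (key, value) rows, then rebuild one map per id.
import org.apache.spark.sql.functions.{explode, struct, collect_list, map_from_entries}

df.select($"id", explode($"data"))   // explode on a map yields key and value columns
  .groupBy("id")
  .agg(map_from_entries(collect_list(struct($"key", $"value"))).as("merged_data"))
  .show(false)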
