Conditional concatenation in Spark



I have a DataFrame with the following structure:

+----------+------+------+----------------+--------+------+
|      date|market|metric|aggregator_Value|type    |rank  |
+----------+------+------+----------------+--------+------+
|2018-08-05|    m1|   16 |              m1|median  |  1   |
|2018-08-03|    m1|    5 |              m1|median  |  2   |
|2018-08-01|    m1|   10 |              m1|mean    |  3   |
|2018-08-05|    m2|   35 |              m2|mean    |  1   |
|2018-08-03|    m2|   25 |              m2|mean    |  2   |
|2018-08-01|    m2|    5 |              m2|mean    |  3   |
+----------+------+------+----------------+--------+------+

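In this DataFrame, the rank column is computed by partitioning on the market column and ordering by date in descending order, like this: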

val w_rank = Window.partitionBy("market").orderBy(desc("date"))
val outputDF2=outputDF1.withColumn("rank",rank().over(w_rank))

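For each market, I want to keep only the rank = 1 row and add a column containing the concatenated metric values, subject to a condition: if type = "median" in the rank = 1 row, concatenate all metric values for that market; otherwise, if type = "mean" in the rank = 1 row, concatenate only the first two metric values (rank 1 and rank 2). Like this: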

+----------+------+------+----------------+--------+---------+
|      date|market|metric|aggregator_Value|type    |result   |
+----------+------+------+----------------+--------+---------+
|2018-08-05|    m1|   16 |              m1|median  |10|5|16  |
|2018-08-05|    m2|   35 |              m2|mean    |25|35    |
+----------+------+------+----------------+--------+---------+    

How can I achieve this?

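You can set the metric column to null according to the condition, then apply collect_list followed by concat_ws to get the desired result, as shown below. This works because collect_list skips null values.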

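import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._  // assumes a SparkSession named `spark`, as in spark-shell

val df = Seq(
  ("2018-08-05", "m1", 16, "m1", "median", 1),
  ("2018-08-03", "m1",  5, "m1", "median", 2),
  ("2018-08-01", "m1", 10, "m1", "mean",   3),
  ("2018-08-05", "m2", 35, "m2", "mean",   1),
  ("2018-08-03", "m2", 25, "m2", "mean",   2),
  ("2018-08-01", "m2",  5, "m2", "mean",   3)
).toDF("date", "market", "metric", "aggregator_value", "type", "rank")

val win_desc = Window.partitionBy("market").orderBy(desc("date"))
val win_asc = Window.partitionBy("market").orderBy(asc("date"))

df.
  // Type of the rank-1 (latest date) row, copied onto every row of the market
  withColumn("rank1_type", first($"type").over(win_desc.rowsBetween(Window.unboundedPreceding, Window.currentRow))).
  // Null out metrics beyond the first two ranks when the rank-1 type is "mean"
  withColumn("cond_metric", when($"rank1_type" === "mean" && $"rank" > 2, null).otherwise($"metric")).
  // Collect the remaining metrics in ascending date order and join them with "|"
  withColumn("result", concat_ws("|", collect_list("cond_metric").over(win_asc))).
  where($"rank" === 1).
  show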
// +----------+------+------+----------------+------+----+----------+-----------+-------+
// |      date|market|metric|aggregator_value|  type|rank|rank1_type|cond_metric| result|
// +----------+------+------+----------------+------+----+----------+-----------+-------+
// |2018-08-05|    m1|    16|              m1|median|   1|    median|         16|10|5|16|
// |2018-08-05|    m2|    35|              m2|  mean|   1|      mean|         35|  25|35|
// +----------+------+------+----------------+------+----+----------+-----------+-------+
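
To make the rule itself easier to check independently of Spark, here is a minimal sketch of the same logic on plain Scala collections. The names `MetricRow`, `ConditionalConcat` and `concatPerMarket` are hypothetical and used only for illustration; the sketch assumes dates are ISO-formatted strings (so they sort correctly as text) and that each market has at least one row.

```scala
object ConditionalConcat {
  // Hypothetical row type; fields mirror the relevant columns of the question
  final case class MetricRow(date: String, market: String, metric: Int, kind: String)

  val rows: Seq[MetricRow] = Seq(
    MetricRow("2018-08-05", "m1", 16, "median"),
    MetricRow("2018-08-03", "m1",  5, "median"),
    MetricRow("2018-08-01", "m1", 10, "mean"),
    MetricRow("2018-08-05", "m2", 35, "mean"),
    MetricRow("2018-08-03", "m2", 25, "mean"),
    MetricRow("2018-08-01", "m2",  5, "mean")
  )

  // For each market: order rows newest first, inspect the newest row's type,
  // keep every metric for "median" or only the two newest otherwise,
  // then join the kept metrics oldest first with "|".
  def concatPerMarket(input: Seq[MetricRow]): Map[String, String] =
    input.groupBy(_.market).map { case (market, group) =>
      val newestFirst = group.sortBy(_.date).reverse
      val kept =
        if (newestFirst.head.kind == "median") newestFirst
        else newestFirst.take(2)
      market -> kept.reverse.map(_.metric).mkString("|")
    }
}
```

Calling `ConditionalConcat.concatPerMarket(ConditionalConcat.rows)` on the sample data yields "10|5|16" for m1 and "25|35" for m2, matching the result column above. The Spark version expresses the same steps with window functions so that the work stays distributed.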
