我想在一个数据帧上编写一个UDF,它的操作是将特定行的值与同一组中的值进行比较,其中分组是通过多个键进行的。由于udf对单行进行操作,因此我想编写一个查询,将同一组中的值作为新列值返回。
例如在这个上面输入:
<表类>
id
categoryAB
categoryXY
value1
value2
tbody><<tr>1 X 0.2 对 2X 0.3 假 3 X 0.2 对 4B X 0.4 真正 5B X 0.1 对 6B Y 0.5 假 表类>
可能有一个更优化的方法,但这里是我通常的做法:
val df = Seq(
(1, "A", "X", 0.2, true),
(2, "A", "X", 0.3, false),
(3, "A", "X", 0.2, true),
(4, "B", "X", 0.4, true),
(5, "B", "X", 0.1, true),
(6, "B", "Y", 0.5, false)
).toDF("id", "categoryAB", "categoryXY", "value1", "value2")
df.join(
df.groupBy("categoryAB", "categoryXY")
.agg(
collect_list('value1) as "group1",
collect_list('value2) as "group2"
),
Seq("categoryAB", "categoryXY")
).show()
这个想法是,我分别计算categoryAB
和categoryXY
上的聚合,然后我将新数据帧连接到原始数据帧(确保df
被缓存,如果它是大量计算的结果,否则它将被计算两次)。