我使用Spark 2.0.0和dataframe。这是我的输入数据帧
| id | year | qty |
|----|-------------|--------|
| a | 2012 | 10 |
| b | 2012 | 12 |
| c | 2013 | 5 |
| b | 2014 | 7 |
| c | 2012 | 3 |
我想要的是
| id | year_2012 | year_2013 | year_2014 |
|----|-----------|-----------|-----------|
| a | 10 | 0 | 0 |
| b | 12 | 0 | 7 |
| c | 3 | 5 | 0 |
或
| id | yearly_qty |
|----|---------------|
| a | [10, 0, 0] |
| b | [12, 0, 7] |
| c | [3, 5, 0] |
我找到的最接近的解决方案是collect_list()
,但这个函数不提供列表的顺序。在我看来,解决方案应该是这样的:
data.groupBy('id').agg(collect_function)
是否有一种方法来生成这没有过滤出每个id使用循环?
第一个可以很容易地实现使用pivot
:
from itertools import chain
years = sorted(chain(*df.select("year").distinct().collect()))
df.groupBy("id").pivot("year", years).sum("qty")
可以进一步转换为数组形式:
from pyspark.sql.functions import array, col
(...
.na.fill(0)
.select("id", array(*[col(str(x)) for x in years]).alias("yearly_qty")))
直接获得第二个可能不值得大惊小怪,因为你必须先填补空白。不过你可以试试:
from pyspark.sql.functions import collect_list, struct, sort_array, broadcast
years_df = sc.parallelize([(x, ) for x in years], 1).toDF(["year"])
(broadcast(years_df)
.join(df.select("id").distinct())
.join(df, ["year", "id"], "leftouter")
.na.fill(0)
.groupBy("id")
.agg(sort_array(collect_list(struct("year", "qty"))).qty.alias("qty")))
还需要Spark 2.0+支持struct
收集。
这两种方法都非常昂贵,所以在使用它们时应该小心。根据经验,长比宽好。