将带有collect_list(列)的火花数据帧转换回长格式



>假设我们有数据帧虹膜:

import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv')
df = spark.createDataFrame(df)

我需要按物种对萼片宽度执行一些聚合函数,例如获取每组 3 个最大值。

import pyspark.sql.functions as F
get_max_3 = F.udf(
    lambda x: sorted(x)[-3:]
)
agged = df.groupBy('species').agg(F.collect_list('sepal_width').alias('sepal_width'))
agged = agged.withColumn('sepal_width', get_max_3('sepal_width'))
+----------+---------------+
|   species|    sepal_width|
+----------+---------------+
| virginica|[3.6, 3.8, 3.8]|
|versicolor|[3.2, 3.3, 3.4]|
|    setosa|[4.1, 4.2, 4.4]|
+----------+---------------+

现在,我如何有效地将其转换为长格式的数据帧(意味着每个物种三行,每行对应一个值)?

有没有办法在不使用collect_list的情况下做到这一点?

要将数据框转换回长格式,可以使用 explode ;但是,要使用此方法,您需要首先修复udf,以便返回正确的类型:

from pyspark.sql.types import *
import pyspark.sql.functions as F
get_max_3 = F.udf(lambda x: sorted(x)[-3:], ArrayType(DoubleType()))
agged = agged.withColumn('sepal_width', get_max_3('sepal_width'))
agged.withColumn('sepal_width', F.explode(F.col('sepal_width'))).show()
+----------+-----------+
|   species|sepal_width|
+----------+-----------+
| virginica|        3.6|
| virginica|        3.8|
| virginica|        3.8|
|versicolor|        3.2|
|versicolor|        3.3|
|versicolor|        3.4|
|    setosa|        4.1|
|    setosa|        4.2|
|    setosa|        4.4|
+----------+-----------+

或者不收集为列表并分解,您可以先对sepal_width列进行排名,然后根据rank进行过滤:

df.selectExpr(
    "species", "sepal_width", 
    "row_number() over (partition by species order by sepal_width desc) as rn"
).where(F.col("rn") <= 3).drop("rn").show()
+----------+-----------+
|   species|sepal_width|
+----------+-----------+
| virginica|        3.8|
| virginica|        3.8|
| virginica|        3.6|
|versicolor|        3.4|
|versicolor|        3.3|
|versicolor|        3.2|
|    setosa|        4.4|
|    setosa|        4.2|
|    setosa|        4.1|
+----------+-----------+

相关内容

  • 没有找到相关文章

最新更新