Suppose we have the iris DataFrame:
import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv')
df = spark.createDataFrame(df)  # assumes an active SparkSession named `spark`
I need to apply some aggregation functions to sepal_width grouped by species, for example getting the 3 largest values per group.
import pyspark.sql.functions as F
get_max_3 = F.udf(
    lambda x: sorted(x)[-3:]
)
agged = df.groupBy('species').agg(F.collect_list('sepal_width').alias('sepal_width'))
agged = agged.withColumn('sepal_width', get_max_3('sepal_width'))
+----------+---------------+
| species| sepal_width|
+----------+---------------+
| virginica|[3.6, 3.8, 3.8]|
|versicolor|[3.2, 3.3, 3.4]|
| setosa|[4.1, 4.2, 4.4]|
+----------+---------------+
Now, how can I efficiently convert this into a long-format DataFrame (meaning three rows per species, one per value)?
Is there a way to do this without using collect_list?
To convert the DataFrame back to long format, you can use explode; however, for this to work you first need to fix the udf so that it returns the correct type:
from pyspark.sql.types import *
import pyspark.sql.functions as F
get_max_3 = F.udf(lambda x: sorted(x)[-3:], ArrayType(DoubleType()))
agged = agged.withColumn('sepal_width', get_max_3('sepal_width'))
agged.withColumn('sepal_width', F.explode(F.col('sepal_width'))).show()
+----------+-----------+
| species|sepal_width|
+----------+-----------+
| virginica| 3.6|
| virginica| 3.8|
| virginica| 3.8|
|versicolor| 3.2|
|versicolor| 3.3|
|versicolor| 3.4|
| setosa| 4.1|
| setosa| 4.2|
| setosa| 4.4|
+----------+-----------+
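As a side note, on Spark 2.4+ you can likely avoid the Python udf altogether by sorting and slicing the collected array with built-in functions. This is a minimal sketch, assuming the same df as above and that sort_array, slice and explode are available in your Spark version:
import pyspark.sql.functions as F
# collect the values per group, sort the array descending, keep the first 3 elements,
# then explode back to long format
agged = df.groupBy('species').agg(F.collect_list('sepal_width').alias('sepal_width'))
agged.withColumn(
    'sepal_width',
    F.explode(F.slice(F.sort_array('sepal_width', asc=False), 1, 3))
).show()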
Alternatively, instead of collecting the values into a list and exploding, you can assign a row number within each species (ordered by sepal_width descending) and then filter on it:
df.selectExpr(
"species", "sepal_width",
"row_number() over (partition by species order by sepal_width desc) as rn"
).where(F.col("rn") <= 3).drop("rn").show()
+----------+-----------+
| species|sepal_width|
+----------+-----------+
| virginica| 3.8|
| virginica| 3.8|
| virginica| 3.6|
|versicolor| 3.4|
|versicolor| 3.3|
|versicolor| 3.2|
| setosa| 4.4|
| setosa| 4.2|
| setosa| 4.1|
+----------+-----------+
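The same ranking approach can also be written with the DataFrame Window API instead of a SQL expression; this is a minimal sketch of the equivalent, assuming the same df as above:
from pyspark.sql.window import Window
import pyspark.sql.functions as F
# row number within each species, largest sepal_width first
w = Window.partitionBy('species').orderBy(F.col('sepal_width').desc())
df.withColumn('rn', F.row_number().over(w)).where(F.col('rn') <= 3).drop('rn').show()
Like the selectExpr version, this avoids collect_list entirely, so the values are never gathered into a single array per group.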