我有两个数据帧,使用groupby
后我在 agg 中使用collect_set()
。聚合后flatMap
生成的数组的最佳方法是什么。
schema = ['col1', 'col2', 'col3', 'col4']
a = [[1, [23, 32], [11, 22], [9989]]]
df1 = spark.createDataFrame(a, schema=schema)
b = [[1, [34], [43, 22], [888, 777]]]
df2 = spark.createDataFrame(b, schema=schema)
df = df1.union(
df2
).groupby(
'col1'
).agg(
collect_set('col2').alias('col2'),
collect_set('col3').alias('col3'),
collect_set('col4').alias('col4')
)
df.collect()
我得到这个作为输出:
[Row(col1=1, col2=[[34], [23, 32]], col3=[[11, 22], [43, 22]], col4=[[9989], [888, 777]])]
但是,我希望将其作为输出:
[Row(col1=1, col2=[23, 32, 34], col3=[11, 22, 43], col4=[9989, 888, 777])]
您可以使用
udf
:
from itertools import chain
from pyspark.sql.types import *
from pyspark.sql.functions import udf
flatten = udf(lambda x: list(chain.from_iterable(x)), ArrayType(IntegerType()))
df.withColumn('col2_flat', flatten('col2'))
如果没有 UDF,我认为这应该可以工作:
from pyspark.sql.functions import array_distinct, flatten
df.withColumn('col2_flat', array_distinct(flatten('col2')))
它将平展嵌套数组,然后进行重复数据删除。