Hi, I'm trying to write PySpark code to build a list from a dataframe column. I'm using the collect() function in my code, but I'm not sure it is the right way to get a list of filter values from a dataframe column. Since collect() brings all the data back to the driver, it would be a poor choice for a large dataframe (say 10 GB).
Below is my input dataframe -
[Row(parent=u'p1', child=u'c1'), Row(parent=u'p11', child=u'p1'),
Row(parent=u'p111', child=u'p11'), Row(parent=u'p2', child=u'c2'),
Row(parent=u'p22', child=u'p2'), Row(parent=u'p222', child=u'p22'),
Row(parent=u'p2222', child=u'p222')]
I want to achieve an output dataframe like the following -
[Row(parent=u'p2222', child1=u'p222', child2=u'p22', child3=u'p2',
child4=u'c2'), Row(parent=u'p111', child1=u'p11', child2=u'p1',
child3=u'c1', child4=None)]
Below is the working code I wrote, but I'm not sure it is well optimized, given that Spark is known for its optimized processing -
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql.types import *
customSchema = StructType([StructField('parent',StringType(),True),
StructField('child',StringType(),True)])
#loading data from a CSV file and creating a dataframe
mydata = sqlContext.load(source='com.databricks.spark.csv',path='/FileStore/tables/34v0qouq1507635707462/parent_child_input.csv',header=True,schema=customSchema)
mydata.registerTempTable('mydata')
#creating a list of values of column "Child" from the dataframe "mydata"
childlist = [x[1] for x in mydata.collect()]
#creating another dataframe with filter values of "Parent" column which are not present in childlist
level1 = mydata.selectExpr('parent','child as child1').where(~mydata.parent.isin(childlist))
i=1
#Function to create dataframe containing desired output as mentioned above
def getChild(level1, i):
    cname = 'child' + str(i)
    # collect the values of the current child column (again pulls data to the driver)
    tmp = [x[i] for x in level1.collect() if x[i]]
    tmp = list(set(tmp))
    if tmp.count(None) == 1:
        tmp.remove(None)
    level1.registerTempTable('level1')
    if len(tmp) > 0:
        i += 1
        ccname = 'child' + str(i)
        # join the current level back to mydata to pick up the next child level
        querystr = ('select level1.*, mydata.child as ' + ccname +
                    ' from level1 left outer join mydata on level1.' + cname + '=mydata.parent')
        level1 = sqlContext.sql(querystr)
        level1 = getChild(level1, i)
    return level1
level1 = getChild(level1,i)
level1.drop('child5').show()
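On the collect() concern above: the first filtering step (building level1) could probably be pushed down to Spark with an anti join instead of collecting childlist to the driver. A rough, untested sketch assuming the same mydata dataframe and Spark 2.x (where the 'left_anti' join type is available):
# roots are parents that never appear in the child column;
# the anti join keeps only mydata rows whose parent has no match among the children
children = mydata.select(mydata.child.alias('c')).distinct()
level1 = mydata.join(children, mydata.parent == children.c, 'left_anti') \
               .selectExpr('parent', 'child as child1')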
If your parent and child values will always contain an integer in your input data, we can use that integer to split and group the rows. Tried it my way; hope it helps.
>>> df.show()
+-----+------+
|child|parent|
+-----+------+
| c1| p1|
| p1| p11|
| p11| p111|
| c2| p2|
| p2| p22|
| p22| p222|
| p222| p2222|
+-----+------+
>>> from pyspark.sql import functions as F
>>> from pyspark.sql.types import ArrayType, StringType, StructType, StructField
>>> udf = F.udf(lambda x,y : (x,y),ArrayType(StringType()))
>>> udf1 = F.udf(lambda x : tuple(filter(str.isdigit,x))[0],StringType())
>>> df1 = df.select("*",udf1('parent').alias('group'),udf('parent','child').alias('set'))
>>> df1.show()
+-----+------+-----+-------------+
|child|parent|group| set|
+-----+------+-----+-------------+
| c1| p1| 1| [p1, c1]|
| p1| p11| 1| [p11, p1]|
| p11| p111| 1| [p111, p11]|
| c2| p2| 2| [p2, c2]|
| p2| p22| 2| [p22, p2]|
| p22| p222| 2| [p222, p22]|
| p222| p2222| 2|[p2222, p222]|
+-----+------+-----+-------------+
>>> udf2 = F.udf(lambda x :sorted(set(sum(x,[])),reverse=True),ArrayType(StringType()))
>>> df2 = df1.groupby('group').agg(udf2(F.collect_set('set')).alias('column'))
>>> df2.show(truncate=False)
+-----+--------------------------+
|group|column |
+-----+--------------------------+
|1 |[p111, p11, p1, c1] |
|2 |[p2222, p222, p22, p2, c2]|
+-----+--------------------------+
>>> # width of the longest hierarchy, used to pad shorter groups with None
>>> maxval = df2[[F.max(F.size('column'))]].first()[0]
>>> schema = StructType([StructField("parent",StringType(),True),StructField("child1",StringType(),True),StructField("child2",StringType(),True),StructField("child3",StringType(),True),StructField("child4",StringType(),True)])
>>> udf3 = F.udf(lambda x : x if len(x) == maxval else x + [None]*(maxval - len(x)), schema)
>>> df2.select("*",udf3('column').alias('merged')).select("merged.*").show()
+------+------+------+------+------+
|parent|child1|child2|child3|child4|
+------+------+------+------+------+
| p111| p11| p1| c1| null|
| p2222| p222| p22| p2| c2|
+------+------+------+------+------+
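A note on the schema used by udf3: it is hardcoded up to child4. If the depth of the hierarchy is not known in advance, the same schema could be derived from maxval instead; a small sketch assuming the maxval computed above:
>>> # one 'parent' field plus child1..child(maxval-1), all nullable strings
>>> fields = [StructField('parent', StringType(), True)]
>>> fields += [StructField('child' + str(i), StringType(), True) for i in range(1, maxval)]
>>> schema = StructType(fields)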