TLDR:我是pyspark的新手,我想我不太"sparky"当尝试做一堆聚合时。
我有一组数据,我需要知道每个分类列的每一层数据的比例。例如,如果我以以下开头:
|box|potato|country|
|1 |red |usa |
|1 |red |mexico |
|1 |yellow|canada |
|1 |red |canada |
|1 |red |mexico |
我想以:
结尾|box|potato.red|potato.yellow|country.usa|country.mexico|country.canada|
|1 |0.80 |0.20 |0.20 |0.40 |0.40 |
因此,我的输出是一个数据框,每个数据集有一行关于我的土豆箱,并且有一些列等于每个分类列的每个唯一级别的总和(在本例中为2 + 3 = 5)加上数据集标识符的1。
我想优化我的方法,作为实际的,非马铃薯数据,更大,有~100个分类列,从2到100个级别。
我认为我的主要问题是,当我在pyspark中实现这个时-我实际上并没有在我如何做到这一点上"sparky"。任何类型的优化使用更pyspark合适的方法是赞赏的。
当前的方法:
box_of_potatos = <readindata>
sz = box_of_potatos.count()
proportion = count(*)/sz
output_for_one_box = process_potatos(box_of_potatos, 1)
def process_potatos(df, dataset_id):
c_str = [c for c, t in df.dtypes if t.startswith('string')]
foo = spark.createDataFrame([Row(id = dataset_id)])
for i in c_str:
bar = categorical_lvl(df, i, dataset_id)
foo = foo.join(bar, ['dataset_id'])
return(foo)
def categorical_lvl(df, cat_attribute, dataset_id):
newCols = [cat_attribute, dataset_id]
smry = df_demo.groupBy(cat_attribute).agg(proportion)
smry = smry.toDF(*newCols).withColumn('column_modified', f.concat(f.lit(cat_attribute), f.lit("."), f.col(cat_attribute))).withColumn('id', f.lit(dataset_id))
smry = smry.groupBy('id').pivot('column_modified').avg(dataset_id).fillna(0)
return(smry)
Your DataFrame (df):
+---+------+-------+
|box|potato|country|
+---+------+-------+
| 1| red| usa|
| 1| red| mexico|
| 1|yellow| canada|
| 1| red| canada|
| 1| red| mexico|
+---+------+-------+
尝试PySparkpivot()
获得所需的结果
from pyspark.sql.functions import count
total_count = df.count()
# Helps in adding prefix to columns
rename_columns = lambda prefix, columns: [f"{prefix}.{column}" if column != "box" else column for column in columns]
pivot_potato = df.groupBy("box").pivot("potato").agg((count("potato") / total_count))
pivot_potato = pivot_potato.toDF(*rename_columns("potato", pivot_potato.columns))
pivot_country = df.groupBy("box").pivot("country").agg(count("country") / total_count)
pivot_country = pivot_country.toDF(*rename_columns("country", pivot_country.columns))
pivot_potato.join(pivot_country, "box", "inner").show()
输出+---+----------+-------------+--------------+--------------+-----------+
|box|potato.red|potato.yellow|country.canada|country.mexico|country.usa|
+---+----------+-------------+--------------+--------------+-----------+
| 1| 0.8| 0.2| 0.4| 0.4| 0.2|
+---+----------+-------------+--------------+--------------+-----------+