用Spark DataFrame中的另一个分类列的平均值替换列的空值



我有一个像这样的数据集

id    category     value
1     A            NaN
2     B            NaN
3     A            10.5
5     A            2.0
6     B            1.0

我想用各自类别的平均值填充NAN值。如下所示

id    category     value
1     A            4.16
2     B            0.5
3     A            10.5
5     A            2.0
6     B            1.0

我尝试使用

组来计算每个类别的第一个平均值
val df2 = dataFrame.groupBy(category).agg(mean(value)).rdd.map{
      case r:Row => (r.getAs[String](category),r.get(1))
    }.collect().toMap
    println(df2)

我得到了每个类别及其各自的平均值的地图。output: Map(A ->4.16,B->0.5) 现在,我尝试在SparkSQL中更新查询以填充列,但似乎SPQRKSQL DOSNT支持更新查询。我试图用数据框架填充零值,但没有这样做。我能做些什么?我们可以在大熊猫中所示的熊猫中进行相同的操作:如何用groupby的平均值填充零值?但是我该如何使用Spark DataFrame

最简单的解决方案是使用groupby并加入:

 val df2 = df.filter(!(isnan($"value"))).groupBy("category").agg(avg($"value").as("avg"))
 df.join(df2, "category").withColumn("value", when(col("value").isNaN, $"avg").otherwise($"value")).drop("avg")

请注意,如果所有NAN都有类别,将从结果中删除

的确,您不能使用 selectjoin之类的函数,但您可以使用 update dateframes。在这种情况下,您可以将分组结果保留为DataFrame,并将其(在category列上)加入原始结果,然后执行将NaN s替换为平均值的映射:

import org.apache.spark.sql.functions._
import spark.implicits._
// calculate mean per category:
val meanPerCategory = dataFrame.groupBy("category").agg(mean("value") as "mean")
// use join, select and "nanvl" function to replace NaNs with the mean values:
val result = dataFrame
  .join(meanPerCategory, "category")
  .select($"category", $"id", nanvl($"value", $"mean")).show()

我偶然发现了同一问题,并遇到了这篇文章。但是尝试了其他解决方案,即使用窗口函数。下面的代码在PYSPARK 2.4.3上测试(Spark 1.4可从窗口功能获得)。我相信这是更干净的解决方案。这篇文章很安静,但希望这个答案对他人有帮助。

from pyspark.sql import Window
from pyspark.sql.functions import *
df = spark.createDataFrame([(1,"A", None), (2,"B", None), (3,"A",10.5), (5,"A",2.0), (6,"B",1.0)], ['id', 'category', 'value'])
category_window = Window.partitionBy("category")
value_mean = mean("value0").over(category_window)
result = df
  .withColumn("value0", coalesce("value", lit(0)))
  .withColumn("value_mean", value_mean)
  .withColumn("new_value", coalesce("value", "value_mean"))
  .select("id", "category", "new_value")
result.show()

输出将如预期(有问题):

id  category    new_value       
1   A   4.166666666666667
2   B   0.5
3   A   10.5
5   A   2
6   B   1

最新更新