I have a field of type ArrayType[Struct] in my Spark DataFrame. The field's schema looks like this:
|-- categories: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- categoryId: integer (nullable = true)
| | |-- confidence: float (nullable = true)
|-- count: integer (nullable = true)
|-- naming: integer (nullable = true)
In the source data there are multiple categories, each with a confidence:
categoryId1| categoryConfidence1| categoryId2| categoryConfidence2| categoryId3| categoryConfidence3
1| 0.34| 2| 0.57| 3| 0.89
I want to filter the categories and confidences so that only the entry with the maximum confidence remains, which should look like this:
categoryId3| categoryConfidence3
3| 0.89
Apart from these fields, I want to keep all the other fields in the DataFrame. The final expected result is:
|-- categories: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- categoryId: integer (nullable = true) //corresponding to max confidence value
| | |-- confidence: float (nullable = true) //only max confidence
|-- count: integer (nullable = true)
|-- naming: integer (nullable = true)
My current solution creates an extra confidence column, which is not what I need:
val categoriesWindow = Window.partitionBy("categories.categoryId", "categories.confidence")
val res = df
  .withColumn("category", explode($"categories"))
  .withColumn("confidence", max($"category.confidence").over(categoriesWindow))
  .drop("categories")
What can I do to improve this solution?
Assuming Spark 2.4 and using higher-order functions:
val df = Seq((10,"abc")).toDF("count","naming")
val df2 = df.withColumn("categories",expr(""" array(named_struct('categoryId',1,'confidence',0.34),
named_struct('categoryId',2,'confidence',0.57),
named_struct('categoryId',3,'confidence',0.89)
) """)).select("categories","count","naming")
df2.printSchema
root
|-- categories: array (nullable = false)
| |-- element: struct (containsNull = false)
| | |-- categoryId: integer (nullable = false)
| | |-- confidence: decimal(2,2) (nullable = false)
|-- count: integer (nullable = false)
|-- naming: string (nullable = true)
df2.show(false)
+---------------------------------+-----+------+
|categories |count|naming|
+---------------------------------+-----+------+
|[[1, 0.34], [2, 0.57], [3, 0.89]]|10 |abc |
+---------------------------------+-----+------+
val df3 = df2.withColumn("x_max", expr("""array_max(categories.categoryId) """))
df3.createOrReplaceTempView("cassie")
df3.show(false)
+---------------------------------+-----+------+-----+
|categories |count|naming|x_max|
+---------------------------------+-----+------+-----+
|[[1, 0.34], [2, 0.57], [3, 0.89]]|10 |abc |3 |
+---------------------------------+-----+------+-----+
spark.sql(""" select filter(categories, a -> a.categoryid=x_max ) category, count, naming from cassie """).show(false)
+-----------+-----+------+
|category |count|naming|
+-----------+-----+------+
|[[3, 0.89]]|10 |abc |
+-----------+-----+------+
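The SQL `filter` higher-order function used above behaves like Scala's collection `filter`. A plain-Scala sketch of this step, using a hypothetical tuple model of the (categoryId, confidence) structs:

```scala
// Mirror of: array_max(categories.categoryId), then
// filter(categories, a -> a.categoryid = x_max)
val categories = Seq((1, BigDecimal("0.34")), (2, BigDecimal("0.57")), (3, BigDecimal("0.89")))
val xMax = categories.map(_._1).max            // largest categoryId, like array_max
val category = categories.filter(_._1 == xMax) // keep only entries matching the max id
println(category)
```

Note that this step filters on the maximum categoryId, which only coincides with the maximum confidence in this sample data; Update-1 below filters on confidence directly.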
Update-1:
If you don't need the view, you can use the following:
df2.withColumn("x_max", expr("""array_max(categories.confidence) """))
.withColumn("categories2", expr(""" filter(categories, a -> a.confidence=x_max ) """) )
.show
+--------------------+-----+------+-----+-----------+
| categories|count|naming|x_max|categories2|
+--------------------+-----+------+-----+-----------+
|[[1, 0.34], [2, 0...| 10| abc| 0.89|[[3, 0.89]]|
+--------------------+-----+------+-----+-----------+
Update-2:
df2.withColumn("x_max", expr("""array_max(categories.confidence) """).cast("double")).printSchema
root
|-- categories: array (nullable = false)
| |-- element: struct (containsNull = false)
| | |-- categoryId: integer (nullable = false)
| | |-- confidence: decimal(2,2) (nullable = false)
|-- count: integer (nullable = false)
|-- naming: string (nullable = true)
|-- x_max: double (nullable = true)
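The cast matters because `confidence` is `decimal(2,2)`, so `array_max` yields a decimal rather than a double. In plain-Scala terms (using `BigDecimal` as a stand-in for Spark's decimal type):

```scala
// array_max over decimal(2,2) values yields a decimal; cast("double")
// converts the result to a plain double.
val confidences = Seq(BigDecimal("0.34"), BigDecimal("0.57"), BigDecimal("0.89"))
val xMaxDecimal = confidences.max      // decimal, like array_max(categories.confidence)
val xMaxDouble  = xMaxDecimal.toDouble // like .cast("double")
println(xMaxDouble)
```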
Given a DataFrame like:
+--------------------+-----+------+
| categories|count|naming|
+--------------------+-----+------+
|[[1, 0.5], [2, 0.6]]| 5| 1|
+--------------------+-----+------+
In cases like this, a udf helps a lot. This might also be doable with the array functions added in version 2.4, although handling struct types with those functions can be a bit tricky. Create a new type such as:
case class CategConfidence(categoryId: Int, confidence: Float)
Spark converts a udf output of Array[CategConfidence] into:
| |-- element: struct (containsNull = true)
| | |-- categoryId: integer (nullable = false)
| | |-- confidence: float (nullable = false)
With a udf, you can process the array in plain Scala:
val udf_getmax = udf {array: Seq[Row] =>
val tupleArr = array.map(row => (row.getAs[Int]("categoryId"), row.getAs[Float]("confidence")))
Array(tupleArr.map { case (a,b) => CategConfidence(a,b) }.sortBy(-_.confidence).head)
}
Then drop the categories column and you get:
val fd1 = df.withColumn("max_confidence_categories", udf_getmax(col("categories"))).drop("categories")
+-----+------+-------------------------+
|count|naming|max_confidence_categories|
+-----+------+-------------------------+
| 5| 1| [[2, 0.6]]|
+-----+------+-------------------------+
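The `sortBy(-_.confidence).head` inside the udf can also be written as `maxBy`, which avoids sorting the whole array just to take one element. A minimal plain-Scala check of that selection logic:

```scala
// Same shape the udf returns: a one-element array holding the
// entry with the highest confidence, selected via maxBy.
case class CategConfidence(categoryId: Int, confidence: Float)

val cats = Seq(CategConfidence(1, 0.5f), CategConfidence(2, 0.6f))
val best = Array(cats.maxBy(_.confidence))
println(best.head)
```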