Getting the max value from a struct in a Spark DataFrame while selecting all fields



I have a field of type ArrayType[Struct] in my Spark DataFrame. The structure of this field is as follows:

|-- categories: array (nullable = true)
|    |-- element: struct (containsNull = true)
|    |    |-- categoryId: integer (nullable = true)
|    |    |-- confidence: float (nullable = true)
|-- count: integer (nullable = true)
|-- naming: integer (nullable = true)

In the source data there are multiple categories, each with a confidence:

categoryId1| categoryConfidence1| categoryId2| categoryConfidence2| categoryId3| categoryConfidence3
1| 0.34| 2| 0.57| 3| 0.89

I want to filter the categories and confidences so that only the maximum confidence is kept, which should look like this:

categoryId3| categoryConfidence3
3| 0.89

Apart from these fields, I want to keep all the other fields in the DataFrame. The final expected schema is:

|-- categories: array (nullable = true)
|    |-- element: struct (containsNull = true)
|    |    |-- categoryId: integer (nullable = true) //corresponding to max confidence value
|    |    |-- confidence: float (nullable = true) //only max confidence
|-- count: integer (nullable = true)
|-- naming: integer (nullable = true)

My current solution creates an extra confidence column, which is not what I need:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{explode, max}

val categoriesWindow = Window.partitionBy("categories.categoryId", "categories.confidence")
val res = df
  .withColumn("category", explode($"categories"))
  .withColumn("confidence", max($"category.confidence").over(categoriesWindow))
  .drop("categories")

What can I do to improve this solution?

Assuming Spark 2.4, this can be done with higher-order functions.

val df = Seq((10, "abc")).toDF("count", "naming")
val df2 = df.withColumn("categories", expr(""" array(named_struct('categoryId',1,'confidence',0.34),
                                                      named_struct('categoryId',2,'confidence',0.57),
                                                      named_struct('categoryId',3,'confidence',0.89)
                                                ) """)).select("categories", "count", "naming")
df2.printSchema
root
|-- categories: array (nullable = false)
|    |-- element: struct (containsNull = false)
|    |    |-- categoryId: integer (nullable = false)
|    |    |-- confidence: decimal(2,2) (nullable = false)
|-- count: integer (nullable = false)
|-- naming: string (nullable = true)
df2.show(false)
+---------------------------------+-----+------+
|categories                       |count|naming|
+---------------------------------+-----+------+
|[[1, 0.34], [2, 0.57], [3, 0.89]]|10   |abc   |
+---------------------------------+-----+------+
val df3 = df2.withColumn("x_max", expr("""array_max(categories.categoryId) """))
df3.createOrReplaceTempView("cassie")
df3.show(false)
+---------------------------------+-----+------+-----+
|categories                       |count|naming|x_max|
+---------------------------------+-----+------+-----+
|[[1, 0.34], [2, 0.57], [3, 0.89]]|10   |abc   |3    |
+---------------------------------+-----+------+-----+
spark.sql(""" select filter(categories, a -> a.categoryid=x_max ) category, count, naming  from cassie """).show(false)
+-----------+-----+------+
|category   |count|naming|
+-----------+-----+------+
|[[3, 0.89]]|10   |abc   |
+-----------+-----+------+

Update 1:

If you don't need a view, you can use the following instead.

df2.withColumn("x_max", expr("""array_max(categories.confidence) """))
   .withColumn("categories2", expr(""" filter(categories, a -> a.confidence=x_max ) """))
   .show
+--------------------+-----+------+-----+-----------+
|          categories|count|naming|x_max|categories2|
+--------------------+-----+------+-----+-----------+
|[[1, 0.34], [2, 0...|   10|   abc| 0.89|[[3, 0.89]]|
+--------------------+-----+------+-----+-----------+
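
The result above still carries the helper columns x_max and categories2. To end up with the schema the question asks for (only categories, count and naming, with categories reduced to the single max-confidence struct), one option, sketched here against the same df2, is to overwrite categories directly and drop the helper column:

val result = df2
  .withColumn("x_max", expr("array_max(categories.confidence)"))
  .withColumn("categories", expr("filter(categories, a -> a.confidence = x_max)"))
  .drop("x_max")
  .select("categories", "count", "naming")

count and naming stay untouched, and categories becomes a one-element array holding the struct with the highest confidence.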

Update 2:

df2.withColumn("x_max", expr("""array_max(categories.confidence) """).cast("double")).printSchema
root
|-- categories: array (nullable = false)
|    |-- element: struct (containsNull = false)
|    |    |-- categoryId: integer (nullable = false)
|    |    |-- confidence: decimal(2,2) (nullable = false)
|-- count: integer (nullable = false)
|-- naming: string (nullable = true)
|-- x_max: double (nullable = true)
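
Update 2 shows that x_max (decimal(2,2) in this toy data) can be cast to double; in the question's actual schema confidence is a float, so if the types don't line up, casting both sides of the comparison keeps the filter well defined. A rough sketch, reusing df2 from above:

df2.withColumn("x_max", expr("array_max(categories.confidence)").cast("double"))
   .withColumn("categories2", expr("filter(categories, a -> cast(a.confidence as double) = x_max)"))
   .drop("x_max")

The equality comparison is fine here because x_max is computed from the same values it is compared against.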

Given a DataFrame such as:

+--------------------+-----+------+
|          categories|count|naming|
+--------------------+-----+------+
|[[1, 0.5], [2, 0.6]]|    5|     1|
+--------------------+-----+------+

In cases like this, a UDF helps a lot. It could probably also be done with the array functions included in version 2.4, although handling struct types with them can be a bit tricky. Create a new type such as:

case class CategConfidence(categoryId: Int, confidence: Float)

Spark converts the UDF output of Array[CategConfidence] into:

|    |-- element: struct (containsNull = true)
|    |    |-- categoryId: integer (nullable = false)
|    |    |-- confidence: float (nullable = false)

With a UDF, you can process the array using plain Scala:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, udf}

val udf_getmax = udf { array: Seq[Row] =>
  val tupleArr = array.map(row => (row.getAs[Int]("categoryId"), row.getAs[Float]("confidence")))
  Array(tupleArr.map { case (a, b) => CategConfidence(a, b) }.sortBy(-_.confidence).head)
}

Then, dropping the categories column, you get:

val fd1 = df.withColumn("max_confidence_categories", udf_getmax(col("categories"))).drop("categories")
fd1.show
+-----+------+-------------------------+
|count|naming|max_confidence_categories|
+-----+------+-------------------------+
|    5|     1|               [[2, 0.6]]|
+-----+------+-------------------------+
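
As a side note, the sort can be avoided with maxBy, and calling .head would throw on an empty categories array; a small sketch of a more defensive variant of the same UDF (relying on the CategConfidence case class and column names defined above):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

// Same idea as udf_getmax, but uses maxBy instead of a full sort and returns
// an empty array when categories is empty instead of throwing on .head
val udf_getmax_safe = udf { array: Seq[Row] =>
  val categs = array.map(row =>
    CategConfidence(row.getAs[Int]("categoryId"), row.getAs[Float]("confidence")))
  if (categs.isEmpty) Array.empty[CategConfidence] else Array(categs.maxBy(_.confidence))
}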
