取Struct中双精度嵌套向量的平均值



我有一个列,它是类似的Spark DataFrame中的结构数组

|-- sTest: array (nullable = true)
|    |-- element: struct (containsNull = true)
|    |    |-- value: string (nullable = true)
|    |    |-- embed: array (nullable = true)
|    |    |    |-- element: integer (containsNull = true)

其具有可变数量的全部具有相等长度的嵌套数组("嵌入"(。

对于每一行,我想取这些嵌入的平均值,并将结果作为列分配给一个新的数据帧(现有的+新的列(。

我读到一些人使用explode,但这不是我想要的。我最终想为每一行做一个聚合,计算平均嵌入(array(float)(。

具有ArrayType(StructType(列的数据帧的最小示例:

val structureData = Seq(
Row(Seq(Row("value1 ", Seq(1, 2, 3)), Row("value1 ", Seq(4, 5, 6)))),
Row(Seq(Row("value2", Seq(4,5,6))), Row("value1 ", Seq(1, 1, 1)))
)
val structureSchema = new StructType()
.add("sTest", ArrayType(new StructType()
.add("value", StringType)
.add("embed", ArrayType(IntegerType))))

所需输出为

Row(2.5, 3.5, 4.5)
Row(2.5, 3, 3.5)

您的数据本质上看起来像一个矩阵,并且您正试图按列汇总矩阵,因此考虑使用org.apache.spark.ml.stat包中的Summarizer是很自然的。

输入数据:

case class sTest(value: String, embed: Seq[Int])
val df = Seq(
Tuple1(Seq(
sTest("value1", Seq(1, 2, 3)), 
sTest("value2", Seq(4, 5, 6))
)),
Tuple1(Seq(
sTest("value3", Seq(4, 5, 6)), 
sTest("value4", Seq(1, 1, 1))
))
) toDF("nested")

计算平均值:

import org.apache.spark.sql.functions._
import org.apache.spark.ml.linalg.{Vectors, Vector}
import org.apache.spark.ml.stat.Summarizer
val array2vecUdf = udf((array: Seq[Int]) => {
Vectors.dense(array.toArray.map(_.toDouble))
})
val vec2arrayUdf = udf((vec: Vector) => {
vec.toArray
})
val stage1 = df
// Create a rowid so we can explode, extract the embed field as a vector and collect.
.withColumn("rowid", monotonically_increasing_id)
.withColumn("exp", explode($"nested"))
.withColumn("embed", $"exp".getItem("embed"))
.withColumn("embed_vec", array2vecUdf($"embed"))
val avg = Summarizer.metrics("mean").summary($"embed_vec")
val stage2 = stage1
.groupBy("rowid")
.agg(avg.alias("avg_vec"))
// Convert back from vector to array.
.select(vec2arrayUdf($"avg_vec.mean").alias("avgs"))
stage2.show(false)

结果:

+---------------+
|avgs           |
+---------------+
|[2.5, 3.5, 4.5]|
|[2.5, 3.0, 3.5]|
+---------------+

最新更新