如何将数据帧中的结构映射到案例类



在我的应用程序中的某个时候,我有一个数据帧,其中包含从案例类创建的 Struct 字段。现在我想将其转换/映射回案例类类型:

import spark.implicits._
case class Location(lat: Double, lon: Double)
scala> Seq((10, Location(35, 25)), (20, Location(45, 35))).toDF
res25: org.apache.spark.sql.DataFrame = [_1: int, _2: struct<lat: double, lon: double>]
scala> res25.printSchema
root
 |-- _1: integer (nullable = false)
 |-- _2: struct (nullable = true)
 |    |-- lat: double (nullable = false)
 |    |-- lon: double (nullable = false)

和基本:

res25.map(r => {
   Location(r.getStruct(1).getDouble(0), r.getStruct(1).getDouble(1))
}).show(1)

看起来真的很脏有没有更简单的方法?

在 Spark 1.6+ 中,如果要保留保留的类型信息,请使用数据集 (DS(,而不是数据帧 (DF(。

import spark.implicits._
case class Location(lat: Double, lon: Double)
scala> Seq((10, Location(35, 25)), (20, Location(45, 35))).toDS
res25: org.apache.spark.sql.Dataset[(Int, Location)] = [_1: int, _2: struct<lat: double, lon: double>]
scala> res25.printSchema
root
 |-- _1: integer (nullable = false)
 |-- _2: struct (nullable = true)
 |    |-- lat: double (nullable = false)
 |    |-- lon: double (nullable = false)

它会给你Dataset[(Int, Location)].现在,如果你想再次回到它的案例类起源,那么只需这样做:

scala> res25.map(r => r._2).show(1)
+----+----+
| lat| lon|
+----+----+
|35.0|25.0|
+----+----+

但是,如果你想坚持使用DataFrame API,因为它是动态类型的性质,那么你必须像这样编码:

scala> res25.select("_2.*").map(r => Location(r.getDouble(0), r.getDouble(1))).show(1)
+----+----+
| lat| lon|
+----+----+
|35.0|25.0|
+----+----+

您还可以在 Row 中使用提取器模式,该模式会给您类似的结果,使用更惯用的 scala:

scala> res25.map { row =>
  (row: @unchecked) match {
    case Row(a: Int, Row(b: Double, c: Double)) => (a, Location(b, c))
  }
}
res26: org.apache.spark.sql.Dataset[(Int, Location)] = [_1: int, _2: struct<lat: double, lon: double>]
scala> res26.collect()
res27: Array[(Int, Location)] = Array((10,Location(35.0,25.0)), (20,Location(45.0,35.0)))

我认为其他答案已经确定,但也许他们可能需要其他措辞。

简而言之,不可能在数据帧中使用案例类,因为它们不涉及案例类并使用RowEncoder将内部 SQL 类型映射到Row

正如其他答案所说,您必须使用运算符将基于Row DataFrame转换为as Dataset

val df = Seq((10, Location(35, 25)), (20, Location(45, 35))).toDF
scala> val ds = df.as[(Int, Location)]
ds: org.apache.spark.sql.Dataset[(Int, Location)] = [_1: int, _2: struct<lat: double, lon: double>]
scala> ds.show
+---+-----------+
| _1|         _2|
+---+-----------+
| 10|[35.0,25.0]|
| 20|[45.0,35.0]|
+---+-----------+
scala> ds.printSchema
root
 |-- _1: integer (nullable = false)
 |-- _2: struct (nullable = true)
 |    |-- lat: double (nullable = false)
 |    |-- lon: double (nullable = false)
scala> ds.map[TAB pressed twice]
def map[U](func: org.apache.spark.api.java.function.MapFunction[(Int, Location),U],encoder: org.apache.spark.sql.Encoder[U]): org.apache.spark.sql.Dataset[U]
def map[U](func: ((Int, Location)) => U)(implicit evidence$6: org.apache.spark.sql.Encoder[U]): org.apache.spark.sql.Dataset[U]

相关内容

  • 没有找到相关文章

最新更新