我在Scala代码(Scala 2.13)中有一个这样的函数用于Spark
def getDataset[T <: Product: TypeTag](name:String): Dataset[T] = {
import spark.implicits._
val ds = spark.read.parquet(BASE_PATH + "/" + name).as[T]
ds.createOrReplaceTempView(name)
ds
}
现在我要转一个case类的Seq
,对于每个类,调用这个函数:
case class CLASS1(...)
case class CLASS2(...)
case class CLASS3(...)
Seq(CLASS1, CLASS2, CLASS3, ....).foreach {
c => getDataset[c??](name=c???)
}
我很难弄清楚确切的语法;案例类名称的符号,由foreach
中的变量c
表示,似乎代表了apply
方法(() => Product
)的类型。我真正需要的是要用作类型参数的case类的类型,以及case类的名称。
感觉我应该能够做到这一点-我在这里错过了什么?
看起来可以通过TypeTag
在运行时获取类型参数中使用的类型名称。
我正在讨论的解决方案是这样的:
def getDataset[T <: Product: TypeTag]: Dataset[T] = {
import spark.implicits._
val name = typeTag[T].tpe.typeSymbol.name.toString
val ds = spark.read.parquet(BASE_PATH + "/" + name).as[T]
ds.createOrReplaceTempView(name)
ds
}
然后是Seq(getDataset[CLASS1], getDataset[CLASS2], ...)
不是我所希望的,但至少我可以剪掉类名和字符串的复制粘贴。
您可以为case类定义自己的伴侣对象,并在每个对象中包含一个调用getDataset
的方法。例如,这应该可以工作(由我的心智编译器传递):
abstract class DatasetProvider[T <: Product : TypeTag] {
val name: String
def dataset: Dataset[T] =
getDataset[T](name)
}
case class Class1(...)
object Class1 extends DatasetProvider[Class1] {
override val name: String = "class1"
}
// and so forth for Class2, Class3
Seq(Class1, Class2, Class3).foreach { c =>
val ds = c.dataset
???
}
请注意,如果定义自己的伴侣对象,如果想将其用作函数,则必须显式地将其标记为函数:这可能是可取的,也可能不是可取的。
问题是您想在类型级别替换T
(在编译时已知),在值级别替换name
(在运行时已知)。
正常情况下,T
和name
不同时存在。
一种选择是将Seq(Class1, Class2, Class3)
的值级别替换为Class1 :: Class2 :: Class3 :: HNil
的类型级别,并使用Shapeless
import shapeless.{::, HNil, Poly0, Poly1, Typeable}
import shapeless.ops.hlist.FillWith
import scala.reflect.runtime.universe.{TypeTag, typeOf}
object datasetPoly extends Poly1 {
implicit def cse[T <: Product : TypeTag /*: Typeable*/]: Case.Aux[T, Dataset[T]] =
at(_ => getDataset[T](/*Typeable[T].describe*/typeOf[T].toString))
}
object nullPoly extends Poly0 {
implicit def cse[T >: Null]: Case0[T] = at(null)
}
FillWith[nullPoly.type, Class1 :: Class2 :: Class3 :: HNil].apply().map(datasetPoly)
或者您可以使用宏或运行时反射。在Seq(Class1, Class2, Class3)
中,Class1
、Class2
、Class3
是case类的伴侣对象。例如使用反射工具箱
import scala.reflect.runtime.universe.Quasiquote
import scala.reflect.runtime.{currentMirror => cm}
import scala.tools.reflect.ToolBox
val tb = cm.mkToolBox()
Seq(Class1, Class2, Class3).foreach(c => {
val classSymbol = cm.reflect(c).symbol.companion
tb.eval(q"App.getDataset[$classSymbol](${classSymbol.name.toString})")
})
你应该添加到build.sbt
libraryDependencies += scalaOrganization.value % "scala-reflect" % scalaVersion.value
libraryDependencies += scalaOrganization.value % "scala-compiler" % scalaVersion.value