SPARK java.lang.ClassCastException: org.apache.spark.sql.cat



I am trying to use the type-safe agg functions. I created a case class for the dataset and defined its schema:

case class JosmSalesRecord(orderId: Long,
                           totalOrderSales: Double,
                           totalOrderCount: Long)

object JosmSalesRecord extends SparkSessionWrapper {
  import sparkSession.implicits._
  val schema: Option[StructType] = Some(StructType(Seq(
    StructField("order_id", IntegerType, true),
    StructField("total_order_sales", DoubleType, true),
    StructField("total_order_count", IntegerType, true)
  )))
}

The dataset:

+----------+------------------+---------------+
|   orderId|   totalOrderSales|totalOrderCount|
+----------+------------------+---------------+
|1071131089|            433.68|              8|
|1071386263|  8848.42000000001|            343|
|1071439146|108.39999999999999|              8|
|1071349454|34950.400000000074|            512|
|1071283654|349.65000000000003|             27|
+----------+------------------+---------------+

root
 |-- orderId: long (nullable = false)
 |-- totalOrderSales: double (nullable = false)
 |-- totalOrderCount: long (nullable = false)

I am applying:

val pv = aggregateSum.agg(
  typed.sum[JosmSalesRecord](_.totalOrderSales),
  typed.sumLong[JosmSalesRecord](_.totalOrderCount)
).toDF("totalOrderSales", "totalOrderCount")

When I call pv.show(), I get:

java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to com.pipeline.dataset.JosmSalesRecord
at com.FirstApp$$anonfun$7.apply(FirstApp.scala:78)
at org.apache.spark.sql.execution.aggregate.TypedSumDouble.reduce(typedaggregators.scala:32)
at org.apache.spark.sql.execution.aggregate.TypedSumDouble.reduce(typedaggregators.scala:30)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.agg_doConsume_1$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.serializefromobject_doConsume_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.mapelements_doConsume_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.deserializetoobject_doConsume_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.agg_doAggregateWithKeysOutput_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.agg_doAggregateWithoutKey_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:55)
at org.apache.spark.scheduler.Task.run(Task.scala:123)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)

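Note: I strictly need to use the type-safe functions from import org.apache.spark.sql.expressions.scalalang.typed (typed.sum and typed.sumLong). When I apply sum from import org.apache.spark.sql.functions instead, I do get my answer.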

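This happens if aggregateSum is a DataFrame (that is, a Dataset[Row]) rather than a Dataset[JosmSalesRecord]: the typed aggregator is then handed Row objects (GenericRowWithSchema), which cannot be cast to JosmSalesRecord. For example: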

val aggregateSum = Seq((1071131089L, 433.68, 8),(1071386263L,8848.42000000001,343)).toDF("orderId", "totalOrderSales", "totalOrderCount")
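
On a Spark version where typed is still available, one possible fix is to convert the DataFrame to a Dataset[JosmSalesRecord] before aggregating, so that the typed aggregators receive case-class instances instead of rows. The following is only a minimal sketch under that assumption; the object name TypedSumSketch, the helper totals, and the local SparkSession settings are illustrative and not from the original post.

```scala
import org.apache.spark.sql.{DataFrame, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.expressions.scalalang.typed

// Same shape as the case class in the question.
case class JosmSalesRecord(orderId: Long, totalOrderSales: Double, totalOrderCount: Long)

object TypedSumSketch {

  // Hypothetical helper: sums both measures with the typed aggregators.
  def totals(aggregateSum: DataFrame): DataFrame = {
    // Convert the untyped DataFrame to a typed Dataset first, so the lambdas
    // below are applied to JosmSalesRecord objects rather than Row objects.
    val records: Dataset[JosmSalesRecord] =
      aggregateSum.as[JosmSalesRecord](Encoders.product[JosmSalesRecord])

    records
      .agg(
        typed.sum[JosmSalesRecord](_.totalOrderSales),
        typed.sumLong[JosmSalesRecord](_.totalOrderCount))
      .toDF("totalOrderSales", "totalOrderCount")
  }

  def main(args: Array[String]): Unit = {
    // Assumption: a local session is good enough for trying this out.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("typed-sum-sketch")
      .getOrCreate()
    import spark.implicits._

    val aggregateSum = Seq(
      (1071131089L, 433.68, 8L),
      (1071386263L, 8848.42000000001, 343L)
    ).toDF("orderId", "totalOrderSales", "totalOrderCount")

    totals(aggregateSum).show()
    spark.stop()
  }
}
```

The only change relative to the code in the question is the .as[JosmSalesRecord] conversion; the agg call itself is the same.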

In current versions of Spark (i.e. 3.x), typed is deprecated. You can still get type safety with something like this (Spark 3):

case class OutputRecord(totalOrderSales: Double, totalOrderCount: Long)

val pv = aggregateSum.as[JosmSalesRecord]
  // A constant key puts every record into a single group
  .groupByKey(_ => 1)
  .mapGroups((k, vs) =>
    vs.map((x: JosmSalesRecord) => (x.totalOrderSales, x.totalOrderCount))
      // Sum both fields pairwise; reduceOption yields None for an empty group
      .reduceOption((x, y) => (x._1 + y._1, x._2 + y._2))
      .map { case (x, y) => OutputRecord(x, y) }
      .getOrElse(OutputRecord(0.0, 0L)))
pv.show()

+----------------+---------------+
| totalOrderSales|totalOrderCount|
+----------------+---------------+
|9282.10000000001|            351|
+----------------+---------------+
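
Another type-safe option in Spark 3 is a custom Aggregator, which avoids sending all records through a single group. This is only a sketch of that alternative, not part of the original answer; the names SalesTotals and AggregatorSketch are hypothetical, and OutputRecord is reused as the intermediate buffer purely for brevity.

```scala
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

case class JosmSalesRecord(orderId: Long, totalOrderSales: Double, totalOrderCount: Long)
case class OutputRecord(totalOrderSales: Double, totalOrderCount: Long)

// Hypothetical aggregator: input = JosmSalesRecord, buffer = output = OutputRecord.
object SalesTotals extends Aggregator[JosmSalesRecord, OutputRecord, OutputRecord] {
  // Starting value of the running totals.
  def zero: OutputRecord = OutputRecord(0.0, 0L)

  // Fold one input record into the running totals.
  def reduce(b: OutputRecord, a: JosmSalesRecord): OutputRecord =
    OutputRecord(b.totalOrderSales + a.totalOrderSales,
                 b.totalOrderCount + a.totalOrderCount)

  // Combine partial totals computed on different partitions.
  def merge(b1: OutputRecord, b2: OutputRecord): OutputRecord =
    OutputRecord(b1.totalOrderSales + b2.totalOrderSales,
                 b1.totalOrderCount + b2.totalOrderCount)

  // The buffer already has the output shape.
  def finish(reduction: OutputRecord): OutputRecord = reduction

  def bufferEncoder: Encoder[OutputRecord] = Encoders.product[OutputRecord]
  def outputEncoder: Encoder[OutputRecord] = Encoders.product[OutputRecord]
}

object AggregatorSketch {
  def main(args: Array[String]): Unit = {
    // Assumption: a local session is good enough for trying this out.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("aggregator-sketch")
      .getOrCreate()
    import spark.implicits._

    val aggregateSum = Seq(
      (1071131089L, 433.68, 8L),
      (1071386263L, 8848.42000000001, 343L)
    ).toDF("orderId", "totalOrderSales", "totalOrderCount")

    // toColumn turns the aggregator into a TypedColumn usable in a typed select.
    val pv: Dataset[OutputRecord] =
      aggregateSum.as[JosmSalesRecord].select(SalesTotals.toColumn)

    pv.show()
    spark.stop()
  }
}
```

Because the result is a Dataset[OutputRecord], later transformations stay type-checked as well.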

If you already have Scala cats as a dependency, you can also do this with it.
