How to test an exception thrown inside a map in Scala



I have the following Scala function:

def throwError(spark: SparkSession, df: DataFrame): Unit = {
  import spark.implicits._
  val predictionAndLabels = df.select("prediction", "label").map {
    case Row(prediction: Double, label: Double) => (prediction, label)
    case other => throw new IllegalArgumentException(s"Illegal arguments")
  }
  predictionAndLabels.show()
}

I want to test the exception thrown by the function above, but my test fails:

"Testing" should "throw error for datetype" in withSparkSession { spark =>
  // Creating a DataFrame
  val someData = Seq(
    Row(8, Date.valueOf("2016-09-30")),
    Row(9, Date.valueOf("2017-09-30")),
    Row(10, Date.valueOf("2018-09-30"))
  )
  val someSchema = List(
    StructField("prediction", IntegerType, true),
    StructField("label", DateType, true)
  )
  val someDF = spark.createDataFrame(
    spark.sparkContext.parallelize(someData),
    StructType(someSchema)
  )
  // Testing the exception
  val caught = intercept[IllegalArgumentException] {
    throwError(spark, someDF)
  }
  assert(caught.getMessage.contains("Illegal arguments"))
}

If I move throw new IllegalArgumentException(s"Illegal arguments") outside the map call, the test passes.

How can I test the exception raised by the throwError function?

With a Spark DataFrame you cannot catch exceptions at the row level; the exception inside the map is thrown lazily on an executor when show() evaluates the Dataset, and Spark surfaces it wrapped in a SparkException, so intercept[IllegalArgumentException] never sees it directly. If you use an RDD instead, you can achieve what you want.

See this blog post: https://www.nicolaferraro.me/2016/02/18/exception-handling-in-apache-spark/
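One option that keeps your original throwError is to intercept the wrapper exception and assert on its cause: in local mode, Spark usually reports task failures as org.apache.spark.SparkException with the original exception attached as the cause, so the test would do intercept[SparkException] { throwError(spark, someDF) } and then inspect getCause. This is a minimal plain-Scala sketch of that unwrap-and-assert pattern, where RuntimeException stands in for SparkException (an assumption for illustration, so the snippet runs without Spark):

```scala
// Hypothetical stand-in for a Spark job: the original IllegalArgumentException
// is rethrown wrapped in a RuntimeException, the way Spark wraps task failures
// in a SparkException.
def runJob(): Unit =
  try {
    throw new IllegalArgumentException("Illegal arguments")
  } catch {
    case e: Exception => throw new RuntimeException("Job aborted due to stage failure", e)
  }

// Test pattern: catch the wrapper, then assert on its cause.
val caught =
  try { runJob(); None }
  catch { case e: RuntimeException => Some(e) }

assert(caught.get.getCause.isInstanceOf[IllegalArgumentException])
assert(caught.get.getCause.getMessage.contains("Illegal arguments"))
println("cause message: " + caught.get.getCause.getMessage)
```

In the real ScalaTest, the same idea becomes intercept[SparkException] followed by assertions on caught.getCause, instead of intercept[IllegalArgumentException].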

A way to work around the problem:

def throwError(spark: SparkSession, df: DataFrame): Unit = {
  import spark.implicits._
  val countOfRowsBeforeCheck = df.count()
  // Drop rows that do not match the expected (Double, Double) shape
  val predictionAndLabels = df.select("prediction", "label").flatMap {
    case Row(prediction: Double, label: Double) => Iterator((prediction, label))
    case other => Iterator.empty
  }
  val countOfRowsAfterCheck = predictionAndLabels.count()
  // If any row was dropped, the input contained illegal values
  if (countOfRowsAfterCheck != countOfRowsBeforeCheck) {
    throw new IllegalArgumentException(s"Illegal arguments")
  }
  predictionAndLabels.show()
}
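The key idea in the workaround is that flatMap with Iterator.empty silently drops malformed rows, and the before/after counts reveal whether anything was dropped. The same filtering pattern can be shown on a plain Scala collection (no Spark needed; the sample tuples are made up for illustration):

```scala
// Mixed-shape input: only (Double, Double) pairs are considered well-formed.
val rows: List[(Any, Any)] = List((0.5, 1.0), (8, "2016-09-30"), (0.9, 0.0))

// Keep rows matching the expected shape; everything else becomes Iterator.empty.
val predictionAndLabels = rows.iterator.flatMap {
  case (prediction: Double, label: Double) => Iterator((prediction, label))
  case _                                   => Iterator.empty
}.toList

// The mismatch in sizes is what triggers the IllegalArgumentException above.
assert(predictionAndLabels.size != rows.size)
println(predictionAndLabels) // only the two well-typed rows survive
```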

Hope this helps!