I'm running Spark ML cross-validation on a logistic regression, with regParam as part of the paramGrid.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()
val validator = new CrossValidator()
  .setEstimator(estimator)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)
The estimator here includes regParam among its params. Sample code for saving the model:
class MyModelWriter(instance: MyModel[T]) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    new DefaultParamsWriter(instance).save(path)
    instance.model.save(new Path(path, s"nameOfMofel").toString)
  }
}
MyModel does include regParam among its params:
MyModel extends HasRegParam
When I call model.save(path), this is the exception I get:
java.lang.IllegalArgumentException: requirement failed: ValidatorParams save requires all Params in estimatorParamMaps to apply to this ValidatorParams, its Estimator, or its Evaluator. An extraneous Param was found: logreg_2fb5fdbe5012__regParam
  at scala.Predef$.require(Predef.scala:224)
  at org.apache.spark.ml.tuning.ValidatorParams$$anonfun$validateParams$1$$anonfun$apply$1.apply(ValidatorParams.scala:110)
  at org.apache.spark.ml.tuning.ValidatorParams$$anonfun$validateParams$1$$anonfun$apply$1.apply(ValidatorParams.scala:109)
  at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
  at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
  at org.apache.spark.ml.tuning.ValidatorParams$$anonfun$validateParams$1.apply(ValidatorParams.scala:109)
  at org.apache.spark.ml.tuning.ValidatorParams$$anonfun$validateParams$1.apply(ValidatorParams.scala:108)
  at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
  at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
  at org.apache.spark.ml.tuning.ValidatorParams$.validateParams(ValidatorParams.scala:108)
  at org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter.<init>(CrossValidator.scala:257)
  at org.apache.spark.ml.tuning.CrossValidatorModel.write(CrossValidator.scala:242)
  at org.apache.spark.ml.util.MLWritable$class.save(ReadWrite.scala:157)
  at org.apache.spark.ml.tuning.CrossValidatorModel.save(CrossValidator.scala:210)
  at com.criteo.looklike.sink.Sinks$$anonfun$SavePipelineParam1$1.apply(Sinks.scala:111)
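The comment in ValidatorParams.scala at line 105 says:
"Check to make sure all Params apply to this estimator. Throw an error if any do not."
As I read it, the check makes sure that every param in the estimator param maps (regParam in this case) exists on the estimator or evaluator. Here it does exist, on MyModel above.
Can anyone tell me whether my understanding is correct and, if so, what is causing the error? Thanks.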
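I just worked through this exact error.
When adding to the grid, try passing a Param instance, and when you instantiate that Param, match it to the documented Param type, as listed at https://spark.apache.org/docs/latest/api/scala/...
For example, RandomForestRegressor has numTrees: IntParam.
So I built my param grid as follows...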
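import org.apache.spark.ml.param.IntParam
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.ml.tuning.ParamGridBuilder

val rf = new RandomForestRegressor()
// further .set...() calls on rf were elided in the original

// The Param is constructed with rf as its parent, so its uid prefix matches the estimator.
val numTrees = new IntParam(rf, "numTrees", "Number of trees to train (>= 1) (default = 20)")
// For fun/preference, I make the numTrees values grow like the area of a circle.
val numTreesValues = for (n <- 3 to 20 by 3) yield (math.Pi * math.pow(n, 2)).toInt
val paramGrid = new ParamGridBuilder()
  .addGrid(numTrees, numTreesValues)
  .build()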
In other words, construct the Param with the estimator as its parent, then pass that Param and its values to .addGrid.
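To make the parent relationship concrete, here is a minimal sketch, not a definitive reading of Spark's internals. It relies on a Param being identified by its parent uid plus its name, which is also the form shown in the error message (logreg_2fb5fdbe5012__regParam). The object name ParamIdentitySketch and the grid values are made up for illustration.

```scala
import org.apache.spark.ml.param.IntParam
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.ml.tuning.ParamGridBuilder

// Hypothetical wrapper object, only here so the snippet compiles on its own.
object ParamIdentitySketch {
  val rf = new RandomForestRegressor()

  // Hand-built Param whose parent is rf, as in the answer above.
  val handBuilt = new IntParam(rf, "numTrees", "Number of trees to train (>= 1)")

  // Params compare by parent uid and name, so this matches rf's own field.
  val sameParam: Boolean = handBuilt == rf.numTrees
  val sameParent: Boolean = rf.numTrees.parent == rf.uid

  // Using the estimator's own field avoids re-declaring the Param.
  // The values 10, 20, 40 are illustrative only.
  val paramGrid = new ParamGridBuilder()
    .addGrid(rf.numTrees, Array(10, 20, 40))
    .build()

  // A Param taken from a different instance carries a different parent uid.
  val other = new RandomForestRegressor()
  val foreignParent: Boolean = other.numTrees.parent != rf.uid
}
```

If that holds for your Spark version, referencing rf.numTrees directly should behave the same as the hand-built IntParam, as long as rf is the very instance you later hand to the validator.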
My validator then looks like this...
val cv = new CrossValidator()
  .setEstimator(rf)
  .setEstimatorParamMaps(paramGrid)
  // further .set...() calls were elided in the original
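Applied to the regParam case from the question, the same idea can be sketched as below. This is a hedged illustration, not the asker's actual setup: it assumes the save-time check looks up each grid Param's parent uid among the validator, its estimator, and its evaluator, which is what the error message suggests. BinaryClassificationEvaluator and the object name GridFromSameInstance are assumptions, since the original evaluator is not shown.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Hypothetical wrapper object, only here so the snippet compiles on its own.
object GridFromSameInstance {
  // One instance, used both to build the grid and as the estimator.
  val lr = new LogisticRegression()

  val paramGrid = new ParamGridBuilder()
    .addGrid(lr.regParam, Array(0.1, 0.01))
    .build()

  val validator = new CrossValidator()
    .setEstimator(lr)
    .setEvaluator(new BinaryClassificationEvaluator()) // assumed evaluator
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(3)

  // Rough approximation of the idea behind the save-time check:
  // every grid Param's parent should be a uid the validator knows about.
  val knownUids: Set[String] =
    Set(validator.uid, validator.getEstimator.uid, validator.getEvaluator.uid)

  val allParamsOwned: Boolean =
    paramGrid.forall(_.toSeq.forall(pair => knownUids.contains(pair.param.parent)))
}
```

In the question, the grid is built from lr.regParam while a different object is passed to setEstimator, so a check of this kind would find logreg_2fb5fdbe5012 missing from the known uids, which would explain the "extraneous Param" message.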