Spark 自定义估算器,包括持久性



我想为 Spark 开发一个自定义估计器,它也处理伟大管道 API 的持久性。但是正如如何在 PySpark mllib 中滚动自定义估算器所说,

那里(还没有)很多文档。我有一些用 Spark 编写的数据清理代码,并希望将其包装在自定义估算器中。包括一些 na 替换、列删除、过滤和基本特征生成(例如出生日期到年龄)。

  • transformSchema 将使用数据集的案例类ScalaReflection.schemaFor[MyClass].dataType.asInstanceOf[StructType]
  • 适合仅
  • 适合例如平均年龄作为NA.替代品

我仍然不清楚的是:

  • 自定义管道模型中的transform将用于转换新数据的"拟合"估算器。这是对的吗?如果是,我应该如何将拟合值(例如平均年龄)从上方转移到模型中?

  • 如何处理持久性?我在私有 spark 组件中找到了一些通用的loadImpl方法,但不确定如何将我自己的参数(例如平均年龄)传输到用于序列化的MLReader/MLWriter中。

如果您能帮助我使用自定义估算器,那就太好了 - 尤其是持久性部分。

首先,我相信你混合了两种不同的东西:

  • Estimators- 表示可以fit的阶段 -ted。Estimatorfit方法采用Dataset并返回Transformer(模型)。
  • Transformers- 表示可以transform数据的阶段。

当您fitPipeline它时,它会fits所有Estimators并返回PipelineModelPipelineModel可以按顺序transform数据,在模型中的所有Transformers上调用transform

我应该如何转移拟合值

这个问题没有单一的答案。通常,您有两种选择:

  • 将拟合模型的参数作为Transformer的参数传递。
  • 使拟合模型的参数ParamsTransformer

第一种方法通常由内置Transformer使用,但第二种方法应该在一些简单的情况下起作用。

如何处理持久性

  • 如果Transformer仅由其Params定义,则可以扩展DefaultParamsReadable
  • 如果使用更复杂的参数,则应扩展MLWritable并实现对数据有意义的MLWriter。Spark 源代码中有多个示例,展示了如何实现数据和元数据读取/写入。

如果您正在寻找一个易于理解的示例,请查看以下CountVectorizer(Model)

  • EstimatorTransformer共享共同Params
  • 模型
  • 词汇表是一个构造函数参数,模型参数继承自父参数。
  • 元数据(参数)使用DefaultParamsWriter/DefaultParamsReader写入读取。
  • 自定义实现处理数据(词汇)写入和读取。

以下内容使用Scala API,但如果您真的想要,您可以轻松地将其重构为 Python。

首先要做的是:

  • 估计器:实现返回转换器的.fit()
  • 转换器:实现.transform()和操作数据帧
  • 序列化
  • /反序列化:尽最大努力使用内置参数并利用简单的DefaultParamsWritable特征+同伴对象扩展DefaultParamsReadable[T]又名 远离 MLReader/MLWriter,保持代码简单。
  • 参数传递:使用扩展Params的共同特征,并在估算器和模型(又名变压器)之间共享

骨架代码:

// Common Parameters
trait MyCommonParams extends Params {
final val inputCols: StringArrayParam = // usage: new MyMeanValueStuff().setInputCols(...)
new StringArrayParam(this, "inputCols", "doc...")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
def getInputCols: Array[String] = $(inputCols)
final val meanValues: DoubleArrayParam = 
new DoubleArrayParam(this, "meanValues", "doc...")
// more setters and getters
}
// Estimator
class MyMeanValueStuff(override val uid: String) extends Estimator[MyMeanValueStuffModel] 
with DefaultParamsWritable // Enables Serialization of MyCommonParams
with MyCommonParams {
override def copy(extra: ParamMap): Estimator[MeanValueFillerModel] = defaultCopy(extra) // deafult
override def transformSchema(schema: StructType): StructType = schema // no changes
override def fit(dataset: Dataset[_]): MyMeanValueStuffModel = {
// your logic here. I can't do all the work for you! ;)
this.setMeanValues(meanValues)
copyValues(new MyMeanValueStuffModel(uid + "_model").setParent(this))
}
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuff extends DefaultParamsReadable[MyMeanValueStuff]
// Model (Transformer)
class MyMeanValueStuffModel(override val uid: String) extends Model[MyMeanValueStuffModel] 
with DefaultParamsWritable // Enables Serialization of MyCommonParams
with MyCommonParams {
override def copy(extra: ParamMap): MyMeanValueStuffModel = defaultCopy(extra) // default
override def transformSchema(schema: StructType): StructType = schema // no changes
override def transform(dataset: Dataset[_]): DataFrame = {
// your logic here: zip inputCols and meanValues, toMap, replace nulls with NA functions
// you have access to both inputCols and meanValues here!
}
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuffModel extends DefaultParamsReadable[MyMeanValueStuffModel]

使用上面的代码,您可以序列化/反序列化包含MyMeanValueStuff阶段的管道。

想看看估算器的一些真正简单的实现吗?最小最大缩放器!(不过我的例子实际上更简单...

相关内容

  • 没有找到相关文章

最新更新