Saving a custom transformer in PySpark



When I implement this piece of Python code in Azure Databricks:

class customTransformations(Transformer):
    <code>
custom_transformer = customTransformations()
....
pipeline = Pipeline(stages=[custom_transformer, assembler, scaler, rf])
pipeline_model = pipeline.fit(sample_data)
pipeline_model.save(<your path>)

When I try to save the pipeline, I get this:

AttributeError: 'customTransformations' object has no attribute '_to_java'

Is there any workaround?

There seems to be no easy workaround other than implementing the `_to_java` method yourself, as suggested for StopWordsRemover here: Serializing a custom transformer using Python to be used within a PySpark ML pipeline

def _to_java(self):
    """
    Convert this instance to a dill dump, then to a list of strings
    holding the integer value of each byte.
    Use this list as a set of dummy stopwords and store it in a
    StopWordsRemover instance.
    :return: Java object equivalent to this instance.
    """
    dmp = dill.dumps(self)
    pylist = [str(d) for d in dmp]  # iterating bytes yields ints in Python 3
    pylist.append(PysparkObjId._getPyObjId())  # add our id so PysparkPipelineWrapper can identify us
    sc = SparkContext._active_spark_context
    java_class = sc._gateway.jvm.java.lang.String
    java_array = sc._gateway.new_array(java_class, len(pylist))
    for i in range(len(pylist)):
        java_array[i] = pylist[i]
    _java_obj = JavaParams._new_java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
    _java_obj.setStopWords(java_array)
    return _java_obj
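The core trick in `_to_java` is the serialization round trip: the whole Python object is pickled, each byte is written out as a decimal string, and that string list is smuggled into the JVM as a StopWordsRemover's stopword list. The sketch below demonstrates just that round trip in plain Python, without Spark; it uses the stdlib `pickle` instead of `dill` (same `dumps`/`loads` interface), and `Greeter` is a hypothetical stand-in for the custom transformer.

```python
import pickle

class Greeter:
    """Hypothetical stand-in for a custom transformer instance."""
    def __init__(self, name):
        self.name = name

# Encode: serialize the object, then map each byte to its decimal string,
# mirroring what _to_java stores as "stopwords".
obj = Greeter("spark")
dmp = pickle.dumps(obj)
pylist = [str(b) for b in dmp]  # iterating bytes yields ints in Python 3

# Decode: the _from_java counterpart reverses the mapping, rebuilding the
# byte string and unpickling the original object.
restored_bytes = bytes(int(s) for s in pylist)
restored = pickle.loads(restored_bytes)
print(restored.name)  # → spark
```

A loader on the other side would read the stopword list back from the StopWordsRemover, strip the appended id marker, and apply the same decode step to recover the transformer.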
