实现接口的类集合的重写方法



我正在使用scikit-learn并正在构建管道。一旦管道建立,我使用GridSearchCV来找到最优模型。我正在研究文本数据,所以我在尝试不同的系统。我创建了一个名为Preprocessor的类,它接受一个stemmer和矢量化器类,然后尝试覆盖矢量化器的方法build_analyzer,以合并给定的stemmer。然而,我看到GridSearchCV的set_params只是直接访问实例变量——也就是说,它不会用新的分析器重新实例化矢量器,就像我一直在做的那样:

class Preprocessor(object):
    # hard code the stopwords for now
    stopwords = nltk.corpus.stopwords.words()
    def __init__(self, stemmer_cls, vectorizer_cls):
        self.stemmer = stemmer_cls()
        analyzer = self._build_analyzer(self.stemmer, vectorizer_cls)
        self.vectorizer = vectorizer_cls(stopwords=stopwords,
                                         analyzer=analyzer,
                                         decode_error='ignore')
    def _build_analyzer(self, stemmer, vectorizer_cls):
        # analyzer tokenizes and lowercases
        analyzer = super(vectorizer_cls, self).build_analyzer()
        return lambda doc: (stemmer.stem(w) for w in analyzer(doc))
    def fit(self, **kwargs):
        return self.vectorizer.fit(kwargs)
    def transform(self, **kwargs):
        return self.vectorizer.transform(kwargs)
    def fit_transform(self, **kwargs):
        return self.vectorizer.fit_transform(kwargs)

所以问题是:我如何重写build_analyzer为所有传入的矢量器类?

是的,GridSearchCV直接设置实例字段,然后对更改字段的分类器调用fit。

scikit-learn中的每个分类器都是以这样的方式构建的,即__init__只设置参数字段,并且进一步工作所需的所有依赖对象(例如在您的情况下调用_build_analyzer)仅在fit方法中构造。你必须添加额外的字段来存储vectorizer_cls,然后你必须在fit方法中构造依赖于vectorized_cls和stemmer_cls对象。

类似:

class Preprocessor(object):
    # hard code the stopwords for now
    stopwords = nltk.corpus.stopwords.words()
    def __init__(self, stemmer_cls, vectorizer_cls):
        self.stemmer_cls = stemmer_cls
        self.vectorizer_cls = vectorizer_cls
    def _build_analyzer(self, stemmer, vectorizer_cls):
        # analyzer tokenizes and lowercases
        analyzer = super(vectorizer_cls, self).build_analyzer()
        return lambda doc: (stemmer.stem(w) for w in analyzer(doc))
    def fit(self, **kwargs):
        analyzer = self._build_analyzer(self.stemmer_cls(), vectorizer_cls)
        self.vectorizer_cls = vectorizer_cls(stopwords=stopwords,
                                         analyzer=analyzer,
                                         decode_error='ignore')
        return self.vectorizer_cls.fit(kwargs)
    def transform(self, **kwargs):
        return self.vectorizer_cls.transform(kwargs)
    def fit_transform(self, **kwargs):
        return self.vectorizer_cls.fit_transform(kwargs)

相关内容

  • 没有找到相关文章