将预先计算的估计值提供给TfidfVectorizer



我训练了scikit-learn的TfidfVectorizer的一个实例,并希望将其持久化到磁盘上。我将IDF矩阵(idf_属性)作为numpy数组保存到磁盘,并将词汇表(vocabulary_)作为JSON对象保存到磁盘(出于安全和其他原因,我避免pickle)。我正在努力做到这一点:

import json
from idf import idf # numpy array with the pre-computed IDFs
from sklearn.feature_extraction.text import TfidfVectorizer
# dirty trick so I can plug my pre-computed IDFs
# necessary because "vectorizer.idf_ = idf" doesn't work,
# it returns "AttributeError: can't set attribute."
class MyVectorizer(TfidfVectorizer):
    TfidfVectorizer.idf_ = idf
# instantiate vectorizer
vectorizer = MyVectorizer(lowercase = False,
                          min_df = 2,
                          norm = 'l2',
                          smooth_idf = True)
# plug vocabulary
vocabulary = json.load(open('vocabulary.json', mode = 'rb'))
vectorizer.vocabulary_ = vocabulary
# test it
vectorizer.transform(['foo bar'])
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/feature_extraction/text.py", line 1314, in transform
    return self._tfidf.transform(X, copy=False)
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/feature_extraction/text.py", line 1014, in transform
    check_is_fitted(self, '_idf_diag', 'idf vector is not fitted')
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/utils/validation.py", line 627, in check_is_fitted
    raise NotFittedError(msg % {'name': type(estimator).__name__})
sklearn.utils.validation.NotFittedError: idf vector is not fitted

那么,我做错了什么?我没能骗过矢量器对象:不知何故,它知道我在作弊(即,传递预先计算的数据,而不是用实际文本训练它)。我检查了矢量器对象的属性,但找不到像"istrained"、"isfitted"等之类的属性。那么,我该如何欺骗矢量器呢?

好吧,我想我明白了:矢量器实例有一个属性_tfidf,而这个属性又必须有一个_idf_diagtransform方法调用check_is_fitted函数,该函数检查该_idf_diag是否存在。(我错过了它,因为它是一个属性的属性。)所以,我检查了TfidfVectorizer源代码,看看_idf_diag是如何创建的。然后我只是将其添加到_tfidf属性中:

import scipy.sparse as sp
# ... code ...
vectorizer._tfidf._idf_diag = sp.spdiags(idf,
                                         diags = 0,
                                         m = len(idf),
                                         n = len(idf))

现在矢量化工作了。

相关内容

  • 没有找到相关文章

最新更新