scikit-learn SGD文档分类器:只使用重要的特征



我有一个包含文档及其描述的文本文件。我使用scikit-learn中提供的SGD分类器来获得两个独立的文档类。我使用以下代码训练了我的模型:

fo = open('training_data.txt','rb')
all_classes = np.array([0,1])
for i,line in enumerate(generate_in_chunks(fo,1000)):
    x = [member.split('^')[2] for member in line if member!="n"]
    y = [member.split('^')[1] for member in line if member!="n"]
    vectorizer = HashingVectorizer(decode_error='ignore', n_features=2 ** 18,non_negative=True)
    x_train =  vectorizer.transform(x)
    y_train = np.asarray(y,dtype=int)
    clf = SGDClassifier(loss='log',penalty='l2',shuffle=True)
    clf.partial_fit(x_train, y_train,classes=all_classes)

现在我在测试数据集上使用这个clf对象。这里我想使用在教程中提到的变换:http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html sklearn.linear_model.SGDClassifier

代码:

fo = open('test_data.txt','rb')
prob_comp = open('pred_prob_actual.txt','wb')
for i,line in enumerate(generate_in_chunks(fo,21000)):
    x = [member.split('^')[2] for member in line if member!="n"]
    y = [member.split('^')[1] for member in line if member!="n"]
    vectorizer = HashingVectorizer(decode_error='ignore', n_features=2 ** 18,non_negative=True)
    x_test =  vectorizer.transform(x)
    y_test = np.asarray(y,dtype=int)
    clf.predict(clf.transform(x_test))
错误:

回溯(最近一次调用):

文件"test.py",第106行clf.predict (clf.transform (x_test))文件"/opt/anaconda2.2/lib/python2.7/site-packages/sklearn/linear_model/base.py",第223行scores = self.decision_function(X)文件"/opt/anaconda2.2/lib/python2.7/site-packages/sklearn/linear_model/base.py",第204行,在decision_function .py中% (X.shape[1], n_features))

ValueError: X每个样本有78个特征;期待206

所以基本上,虽然它已经确定了重要的特征,但它不能在预测测试数据时使用它们。

任何关于我如何在测试数据上使用转换方法的建议将被广泛赞赏。我想只使用重要的特性,并寻找可以帮助做到这一点的方法,只是为了让它更清楚。谢谢。

将最后一行改为:

clf.predict(x_test.toarray())

您正在使用HashingVectorizer转换数据集,但这还不够。您需要应用toarray()以获得预测所基于的特征向量矩阵。

虽然,为了可读性和"更好的"(在我看来)代码结构,我建议你调整你的代码为:

x_train =  vectorizer.fit_transform(x)
...
x_test = vectorizer.transform(x).toarray()
y_test = np.asarray(y,dtype=int)
result = clf.predict(x_test)
print result

相关内容

  • 没有找到相关文章

最新更新