我这里有我的代码,其中它循环遍历每个标签或类别,然后从中创建一个模型。但是,我想要创建一个通用模型,该模型将能够接受来自用户输入的新预测。
我知道下面的代码保存了适合循环中最后一个类别的模型。如何解决此问题,以便保存每个类别的模型,以便在加载这些模型时能够预测新文本的标签?
vectorizer = TfidfVectorizer(strip_accents='unicode',
stop_words=stop_words, analyzer='word', ngram_range=(1,3), norm='l2')
vectorizer.fit(train_text)
vectorizer.fit(test_text)
x_train = vectorizer.transform(train_text)
y_train = train.drop(labels = ['question_body'], axis=1)
x_test = vectorizer.transform(test_text)
y_test = test.drop(labels = ['question_body'], axis=1)
# Using pipeline for applying linearSVC and one vs rest classifier
SVC_pipeline = Pipeline([
('clf', OneVsRestClassifier(LinearSVC(), n_jobs=1)),
])
for category in categories:
print('... Processing {}'.format(category))
# train the SVC model using X_dtm & y
SVC_pipeline.fit(x_train, train[category])
# compute the testing accuracy of SVC
svc_prediction = SVC_pipeline.predict(x_test)
print("SVC Prediction:")
print(svc_prediction)
print('Test accuracy is {}'.format(f1_score(test[category], svc_prediction)))
print("n")
#save the model to disk
filename = 'svc_model.sav'
pickle.dump(SVC_pipeline, open(filename, 'wb'))
代码中有多个错误。
-
您在训练和测试中都适合您的
TfidfVectorizer
:vectorizer.fit(train_text) vectorizer.fit(test_text)
这是错误的。调用
fit()
不是增量的。如果调用两次,它不会学习这两个数据。最近一次打电话给fit()
会忘记过去通话中的所有内容。你永远不会在测试数据上适应(学习)某些东西。你需要做的是这样的:
vectorizer.fit(train_text)
-
管道的工作方式不符合您的预期:
# Using pipeline for applying linearSVC and one vs rest classifier SVC_pipeline = Pipeline([ ('clf', OneVsRestClassifier(LinearSVC(), n_jobs=1)), ])
看到您正在
OneVsRestClassifier
内传递LinearSVC
,因此它将自动使用它而无需Pipeline
。Pipeline
不会在这里做任何事情。 当您按顺序希望通过多个模型传递数据时,Pipeline
很有用。像这样:pipe = Pipeline([ ('pca', pca), ('logistic', LogisticRegression()) ])
上述
pipe
将要做的是将数据传递给PCA
, 这将对其进行转换。然后将新数据传递给LogisticRegression
等等。在您的情况下,管道的正确用法可以是:
SVC_pipeline = Pipeline([ ('vectorizer', vectorizer) ('clf', OneVsRestClassifier(LinearSVC(), n_jobs=1)), ])
在此处查看更多示例:
- https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html#examples-using-sklearn-pipeline-pipeline
-
您需要描述更多关于您的
"categories"
的信息。显示一些数据示例。您没有在任何地方使用y_train
和y_test
。类别与"question_body"
不同吗?