训练并加载了一个sklearn模型,但无法访问该模型进行预测



我已经训练了一个模型并使用类方法保存它。在进行评估时,我想加载训练模型并在我的测试文件上进行预测;但我甚至不能使用我的加载模型。我检查了加载模型的类型,它是一个方法,但我不确定它是否是正确的类型。

另外,我不确定我的save_model方法是否已经保存了训练模型的拟合模型,我想也许这就是我无法访问加载模型的原因。对不起,如果我不是很清楚,请参考我下面的代码-它会更清楚地了解我的意思。

我搜遍了整个互联网,所有的解决方案似乎都没有帮助。所以任何帮助将非常感激!!!!

保存模型方法

# POS class
def __init__(self):
self.vec = DictVectorizer()
self.model = LinearSVC()
def fit_and_report(self, X, Y):
X = self.vec.fit_transform(X)
self.model.fit(X, Y)
def save_model(self, output_file):
with open(output_file, "wb") as outfile:
pickle.dump(self, outfile)

第一次训练模型时,我做了以下工作。

X, Y = pre_process.load_dataset('my_train_file')
X, Y = pre_process.prepare_data_for_training(X, Y)
my_model = function_in_tagger.POS() #call the POS class from a separate script
my_model.fit_and_report(X, Y) # I assumed the self.model.fit(X, Y) in the method is also saved to the model???
my_model.save_model('my_model_file')

加载并执行预测

X_test, Y_test = pre_process.load_dataset('my_test_file') # it returns X and Y (both are list)
loaded_model = pickle.load(open('my_model_file', 'rb'))
y_predict = loaded_model.predict(X_test)
我得到的错误是
y_predict = loaded_model.predict(Y_test)
AttributeError: 'POS' object has no attribute 'predict'

我也尝试获得训练模型的分数,但得到相同的错误

score = loaded_model.score(X_test, Y_test)
AttributeError: 'POS' object has no attribute 'score'
新编辑

我试图使用DictVectorizer转换我的测试数据,但它给了我一个属性错误('DictVectorizer'对象没有属性'feature_names_')

vec = DictVectorizer()
vec.transform(X_test)

然后我尝试在转换之前拟合测试数据,然后它没有给我错误消息。但据我所知,我不应该拟合我的测试数据,只转换。

当同时使用fit和transform方法时,它仍然给我以下错误,尽管

y_predict = loaded_model.predict(Y_test)
AttributeError: 'POS' object has no attribute 'predict'

您的错误源自您的测试数据。假设您在使用训练好的模型之前对测试数据进行转换。你的模型是在转换数据而不是列表上训练的参见下面基于文本分类的示例;

#加载训练好的模型

with open('Rf Classifier', 'rb') as training_model: Rf_model = pickle.load(training_model))

加载模型后,转换测试数据,在我的例子中是数据帧的形式。在这里,因为我使用多个变量,我使用DataframeMapper库来转换测试数据

test_data=mapper.transform(df_test) #df_test refers to the test data which is a dataframe

之后,使用如下所示的训练模型预测转换/矢量化的数据;

df_test["class"] = Rf_model.predict(test_data) # this section creates a new column for predicted class

最新更新