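I'm trying to use logistic regression to classify messages as "spam" or "ham". I'm using the data set from http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection. I found that TF-IDF is the right way to extract features from text, so I used scikit-learn's TfidfVectorizer. Here is my code: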
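import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split  # sklearn.cross_validation in older releases
# The SMS Spam Collection file is tab-separated: label, then message text
msg_df = pd.read_csv('data/sms', delimiter='\t', header=None)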
X_train_data, X_test_data, y_train_data, y_test_data = train_test_split(msg_df[1],msg_df[0])
sms_vectorizer = TfidfVectorizer()
X_train_vector = sms_vectorizer.fit_transform(X_train_data)
X_test_vector = sms_vectorizer.transform(X_test_data)
classifier = LogisticRegression()
classifier.fit(X_train_vector, y_train_data)
sms_predictions = classifier.predict(X_test_vector)
print sms_predictions
for i, prediction in enumerate(sms_predictions[:5]):
    print 'Prediction: %s. Message: %s' % (prediction, X_test_data[i])
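To make the `fit_transform`/`transform` split above concrete, here is a rough pure-Python approximation of what `TfidfVectorizer` computes with its default settings as I understand them: lowercasing, the `(?u)\b\w\w+\b` token pattern, raw term counts, smoothed idf `ln((1 + n) / (1 + df)) + 1`, and L2 normalisation of each row. The toy messages and the helper names (`fit_idf`, `transform`) are hypothetical, the code uses Python 3 syntax, and the real vectorizer returns a sparse matrix rather than lists. The point it illustrates is that the vocabulary and idf weights come from the training messages only, and words never seen in training are simply ignored at transform time.

```python
import math
import re
from collections import Counter

# Approximates TfidfVectorizer's default tokenization: lowercase, tokens of 2+ word chars
TOKEN_RE = re.compile(r"(?u)\b\w\w+\b")


def tokenize(text):
    return TOKEN_RE.findall(text.lower())


def fit_idf(train_docs):
    """Learn the vocabulary and smoothed idf from the training documents only."""
    n_docs = len(train_docs)
    doc_freq = Counter()
    for doc in train_docs:
        doc_freq.update(set(tokenize(doc)))  # count each term once per document
    vocab = sorted(doc_freq)
    # smooth_idf=True: idf(t) = ln((1 + n) / (1 + df(t))) + 1
    idf = {t: math.log((1 + n_docs) / (1 + doc_freq[t])) + 1 for t in vocab}
    return vocab, idf


def transform(docs, vocab, idf):
    """Raw counts times idf, then L2-normalise each row; unseen terms are dropped."""
    rows = []
    for doc in docs:
        counts = Counter(t for t in tokenize(doc) if t in idf)
        row = [counts[t] * idf[t] for t in vocab]
        norm = math.sqrt(sum(v * v for v in row))
        rows.append([v / norm for v in row] if norm else row)
    return rows


# Hypothetical toy messages
train = ["Free prize call now", "Are we meeting now", "Call me later"]
test = ["Call now for a free prize", "totally unseen words"]

vocab, idf = fit_idf(train)            # analogous to fitting on X_train_data
X_train = transform(train, vocab, idf)
X_test = transform(test, vocab, idf)   # analogous to transform(X_test_data)
print(vocab)
```

In the second test message every word is outside the training vocabulary, so its row stays all zeros; the same thing happens with the real vectorizer when a test message shares no terms with the training set.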
When I run the code, I get the following error:
KeyError Traceback (most recent call last)
<ipython-input-19-b5f57158f320> in <module>()
4 print sms_predictions
5 for i, prediction in enumerate(sms_predictions[:5]):
----> 6 print 'Prediction: %s. Message: %s' % (prediction, X_test_data[i])
/usr/local/lib/python2.7/dist-packages/pandas/core/series.pyc in __getitem__(self, key)
555 def __getitem__(self, key):
556 try:
--> 557 result = self.index.get_value(self, key)
558
559 if not np.isscalar(result):
/usr/local/lib/python2.7/dist-packages/pandas/core/index.pyc in get_value(self, series, key)
1788
1789 try:
-> 1790 return self._engine.get_value(s, k)
1791 except KeyError as e1:
1792 if len(self) > 0 and self.inferred_type in ['integer','boolean']:
/usr/local/lib/python2.7/dist-packages/pandas/index.so in pandas.index.IndexEngine.get_value (pandas/index.c:3204)()
/usr/local/lib/python2.7/dist-packages/pandas/index.so in pandas.index.IndexEngine.get_value (pandas/index.c:2903)()
/usr/local/lib/python2.7/dist-packages/pandas/index.so in pandas.index.IndexEngine.get_loc (pandas/index.c:3843)()
/usr/local/lib/python2.7/dist-packages/pandas/hashtable.so in pandas.hashtable.Int64HashTable.get_item (pandas/hashtable.c:6525)()
/usr/local/lib/python2.7/dist-packages/pandas/hashtable.so in pandas.hashtable.Int64HashTable.get_item (pandas/hashtable.c:6463)()
KeyError: 0
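The `KeyError: 0` is consistent with how pandas indexes a `Series`: with an integer index, `X_test_data[i]` looks up the *label* `i`, not the `i`-th position. `train_test_split` returns the test messages with their original, shuffled index labels, so label `0` only exists if row 0 happened to land in the test split. The sketch below reproduces this with a hypothetical hand-built Series whose labels mimic a shuffled split; it needs only pandas and uses Python 3 syntax.

```python
import pandas as pd

# Hypothetical stand-in for X_test_data: rows keep their original
# (shuffled) index labels, so there is no row labelled 0
X_test_data = pd.Series(
    ["win a free prize now", "see you at lunch", "call me later"],
    index=[4127, 88, 3019],
)

try:
    _ = X_test_data[0]      # label-based lookup on an integer index
    label_lookup_failed = False
except KeyError:
    label_lookup_failed = True

first_by_position = X_test_data.iloc[0]   # positional lookup works
print(label_lookup_failed, repr(first_by_position))
```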
Starting from this version of the loop instead:
print sms_predictions
for i, prediction in enumerate(sms_predictions[:5]):
    print i, prediction
    #print 'Prediction: %s. Message: %s' % (prediction, X_test_data[i])
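The loop that only prints `i` and `prediction` never indexes `X_test_data`. A minimal sketch of two ways to print the messages without relying on index labels, assuming the preserved split labels are the cause: pair predictions and messages with `zip`, or look messages up by position with `.iloc`. `X_test_data` and `sms_predictions` here are hypothetical stand-ins built by hand (Python 3 syntax); in the question's code they come from `train_test_split` and `classifier.predict`.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins: shuffled index labels, predictions as a NumPy array
X_test_data = pd.Series(
    ["win a free prize now", "see you at lunch", "call me later"],
    index=[4127, 88, 3019],
)
sms_predictions = np.array(["spam", "ham", "ham"])

# Option 1: iterate predictions and messages in parallel (no index lookup)
pairs = list(zip(sms_predictions[:5], X_test_data))
for prediction, message in pairs:
    print("Prediction: %s. Message: %s" % (prediction, message))

# Option 2: keep enumerate, but look the message up by position
for i, prediction in enumerate(sms_predictions[:5]):
    print("Prediction: %s. Message: %s" % (prediction, X_test_data.iloc[i]))
```

Both rely on the predictions being in the same row order as `X_test_data`, which holds when `predict` is called on a matrix built from `X_test_data` in that order. Calling `X_test_data.reset_index(drop=True)` once after the split would also make labels and positions coincide.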