测试用于文本分类的 SVM 分类器时出错



我已经浏览了sklearn文档,并编写了用于训练SVM分类器以及测试它的代码。但是,在最后一步,我遇到了一个我无法理解的错误。我的代码如下:

rb = open_workbook('subjectcat.xlsx')#C:/Users/5460/Desktop/
wb = copy(rb) #making a copy
sheet = rb.sheet_by_index(0)
#only subjects extracted from excel file     
train_set = () #list
for row_index in range(1,500): #train using 500
    subject = 0
    for col_index in range(1,2):        
        if col_index==1:
            subject = sheet.cell(row_index,col_index).value
            subject = "'" + subject
            train_set = train_set + (subject,)
print 'only subjects'
train = list(train_set)
print len(train_set)
#for t in train_set:
#    print t
vectorizer = TfidfVectorizer(min_df=1) #Tf-idf and CountVector
#extracting features from training data
#corpus = set(train_set)  -- was reducing len to 468
corpus = (train_set)
print len(corpus)
x = vectorizer.fit_transform(corpus)
feature_names = vectorizer.get_feature_names() #use this for toarray() later -- this is to interpret for user
#print feature_names
x_array = x.toarray()
print x_array
print type(x_array)
print len(x_array)
#converting to numpy 2D array
data_array = np.array(x_array)
print type(data_array)
print len(data_array)
print data_array
#only categories extracted from excel file     
cat_set = () #list
for row_index in range(1,500): #train using 500
    subject = 0
    for col_index in range(2,4):        
        if col_index==3:
            category = sheet.cell(row_index,col_index).value
            #in numerical form
            catgory = int(category)
            cat_set = cat_set + (category,)
#for c in cat_set:
#    print c
print 'only categories'
cat_set = list(cat_set)
print len(cat_set)
cat_array = np.array(cat_set)
print cat_array
print type(cat_array)
#################################################################
#data for testing
#only subjects extracted from excel file     
test_set = () #list
for row_index in range(500,575): #train using 500
    subject = 0
    for col_index in range(1,2):        
        if col_index==1:
            subject = sheet.cell(row_index,col_index).value
            subject = "'" + subject
            test_set = test_set + (subject,)
print 'only testing subjects'
test = list(test_set)
print len(test_set)
#extracting features from testing data
test_corpus = (test_set)
print len(test_corpus)
y = vectorizer.fit_transform(test_corpus)
#feature_names = vectorizer.get_feature_names() #use this for toarray() later -- this is to interpret for user
y_array = y.toarray()
#converting to numpy 2D array
test_array = np.array(y_array)
print type(y_array)
print len(y_array)
print y_array
################################################################
def svm_learning(x,y):
    clf = svm.SVC()
    clf.fit(x,y)
    print 'classifier trained'
    return clf #returning classifier
def test_classifier(classifier):
    for t in test_array:
        result = classifier.predict(t)
        print result

classifier = svm_learning(data_array, cat_array)
test_classifier(classifier)

它一直工作到最后,我得到的错误如下:

Traceback (most recent call last):
  File "C:Users5460DesktopCode506_01.py", line 130, in <module>
    test_classifier(classifier)
  File "C:Users5460DesktopCode506_01.py", line 125, in test_classifier
    result = classifier.predict(t)
  File "C:Python27libsite-packagessklearnsvmbase.py", line 466, in predict
    y = super(BaseSVC, self).predict(X)
  File "C:Python27libsite-packagessklearnsvmbase.py", line 282, in predict
    X = self._validate_for_predict(X)
  File "C:Python27libsite-packagessklearnsvmbase.py", line 404, in _validate_for_predict
    (n_features, self.shape_fit_[1]))
ValueError: X.shape[1] = 315 should be equal to 1094, the number of features at training time

我附上结果供裁判参考,如下:

only subjects
499
499
[[ 0.          0.          0.         ...,  0.          0.          0.        ]
 [ 0.          0.          0.         ...,  0.          0.42325613  0.        ]
 [ 0.          0.          0.         ...,  0.          0.42325613  0.        ]
 ..., 
 [ 0.          0.          0.         ...,  0.          0.          0.        ]
 [ 0.          0.          0.         ...,  0.          0.          0.        ]
 [ 0.          0.          0.         ...,  0.          0.          0.        ]]
<type 'numpy.ndarray'>
499
<type 'numpy.ndarray'>
499
[[ 0.          0.          0.         ...,  0.          0.          0.        ]
 [ 0.          0.          0.         ...,  0.          0.42325613  0.        ]
 [ 0.          0.          0.         ...,  0.          0.42325613  0.        ]
 ..., 
 [ 0.          0.          0.         ...,  0.          0.          0.        ]
 [ 0.          0.          0.         ...,  0.          0.          0.        ]
 [ 0.          0.          0.         ...,  0.          0.          0.        ]]
only categories
499
[ 1.  1.  1.  0.  1.  0.  1.  0.  2.  2.  3.  3.  0.  3.  0.  0.  4.  0.
  0.  2.  3.  0.  0.  3.  0.  0.  3.  0.  0.  0.  1.  4.  1.  3.  0.  3.
  0.  3.  2.  3.  0.  0.  3.  2.  4.  0.  3.  2.  3.  2.  3.  3.  0.  0.
  0.  3.  0.  0.  0.  3.  0.  0.  2.  0.  0.  0.  0.  0.  2.  0.  0.  0.
  0.  0.  0.  4.  0.  0.  0.  0.  0.  2.  1.  1.  1.  1.  0.  1.  0.  0.
  0.  3.  0.  0.  0.  3.  3.  2.  0.  3.  0.  3.  3.  4.  1.  3.  3.  0.
  3.  0.  0.  0.  0.  3.  3.  1.  0.  0.  3.  2.  0.  1.  0.  1.  1.  1.
  1.  1.  2.  2.  2.  2.  2.  2.  0.  0.  0.  0.  0.  3.  3.  3.  3.  3.
  0.  3.  3.  0.  3.  0.  3.  3.  0.  0.  0.  3.  3.  1.  3.  3.  3.  0.
  0.  0.  3.  3.  3.  3.  0.  3.  3.  3.  3.  3.  3.  0.  0.  3.  3.  3.
  3.  0.  0.  3.  3.  0.  3.  3.  3.  2.  3.  3.  3.  3.  3.  0.  0.  3.
  3.  3.  3.  0.  3.  3.  3.  0.  3.  3.  4.  0.  3.  0.  0.  2.  3.  0.
  0.  0.  4.  4.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  2.  2.
  4.  2.  2.  0.  0.  0.  2.  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  1.  0.  0.  0.  2.  2.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.
  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.  5.]
<type 'numpy.ndarray'>
only testing subjects
75
75
<type 'numpy.ndarray'>
75
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
classifier trained

有关错误的任何帮助将不胜感激。我不确定缺少什么,或者出了什么问题。提前非常感谢!

y = vectorizer.fit_transform(test_corpus)

重新训练矢量化器以学习测试语料库的词汇,这与训练语料库的词汇不同,因此您可以获得不同的功能。在测试集上使用transform而不是fit_transform

相关内容

  • 没有找到相关文章

最新更新