SVC with class_weight='auto' fails on scikit-learn?



I have the following dataset, which I classify with an SVC (it has 5 labels). When I try to run it with class_weight='auto', like this:

X = tfidf_vect.fit_transform(df['content'].values)
y = df['label'].values

from sklearn import cross_validation
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y)

svm_1 = SVC(kernel='linear', class_weight='auto')
svm_1.fit(X, y)
svm_1_prediction = svm_1.predict(X_test)

Then I get this exception:

Traceback (most recent call last):
  File "test.py", line 62, in <module>
    svm_1.fit(X, y)
  File "/usr/local/lib/python2.7/site-packages/sklearn/svm/base.py", line 140, in fit
    y = self._validate_targets(y)
  File "/usr/local/lib/python2.7/site-packages/sklearn/svm/base.py", line 474, in _validate_targets
    self.class_weight_ = compute_class_weight(self.class_weight, cls, y_)
  File "/usr/local/lib/python2.7/site-packages/sklearn/utils/class_weight.py", line 47, in compute_class_weight
    raise ValueError("classes should have valid labels that are in y")
ValueError: classes should have valid labels that are in y

Then, following a previous question of mine, I tried this:

svm_1 = SVC(kernel='linear', class_weight='auto')
svm_1.fit(X, y_encoded)
svm_1_prediction = le.inverse_transform(svm_1.predict(X))
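
For context, le and y_encoded are never defined in this post; presumably they come from a LabelEncoder fitted in that earlier question. A hypothetical reconstruction (an assumption, not code from the original):

from sklearn.preprocessing import LabelEncoder

# presumed origin of le and y_encoded: map the raw labels to integers 0..n_classes-1
le = LabelEncoder()
y_encoded = le.fit_transform(y)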

The problem is that I get this exception:

  File "/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py", line 179, in accuracy_score
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py", line 74, in _check_targets
    check_consistent_length(y_true, y_pred)
  File "/usr/local/lib/python2.7/site-packages/sklearn/utils/validation.py", line 174, in check_consistent_length
    "%s" % str(uniques))
ValueError: Found arrays with inconsistent numbers of samples: [ 858 2598]

Can someone help me understand what is going wrong above? How do I correctly use SVC's class_weight='auto' parameter so that it balances the data automatically?

Update:

When I run print(y), this is the output (identical to the full listing in the next update):

0       5
1       4
2       5
       ..
2596    2
2597    5

Update:

Then I do the following:

# `test` is a boolean mask built earlier (not shown in the post)
mask = np.array(test)
print y[np.arange(len(y))[~mask]]

This is the output:

0       5
1       4
2       5
3       4
4       4
5       5
6       4
7       4
8       3
9       5
10      4
11      4
12      1
13      4
14      4
15      5
16      4
17      4
18      5
19      5
20      4
21      4
22      5
23      5
24      3
25      3
26      4
27      5
28      4
29      4
       ..
2568    4
2569    4
2570    4
2571    3
2572    4
2573    5
2574    5
2575    5
2576    5
2577    3
2578    4
2579    4
2580    2
2581    4
2582    3
2583    4
2584    5
2585    4
2586    5
2587    4
2588    4
2589    3
2590    5
2591    5
2592    4
2593    4
2594    4
2595    2
2596    2
2597    5
Name: label, dtype: float64
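
Note the dtype: float64 at the end of that listing: integer-looking labels stored as floats are a classic hint that the column contains NaN, since pandas promotes an integer column to float in order to represent missing values. A minimal sketch of the effect:

import numpy as np
import pandas as pd

print(pd.Series([5, 4, 3]).dtype)       # int64
print(pd.Series([5, 4, np.nan]).dtype)  # float64 -- NaN forces the upcast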

The problem is the following:

df.label.unique()
Out[50]: array([  5.,   4.,   3.,   1.,   2.,  nan])
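
The reason a NaN label breaks class_weight='auto' is that NaN never compares equal to itself, so the membership check that compute_class_weight runs over the class labels can never succeed for the NaN "class". A minimal standalone sketch of the mechanism (not scikit-learn's actual source):

import numpy as np

y = np.array([5., 4., 3., 1., 2., np.nan])
classes = np.unique(y)       # NaN shows up as a class of its own
print(np.nan == np.nan)      # False
print(np.in1d(classes, y))   # last entry is False: NaN is never "in" y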

Sample code:

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import SVC
# replace with the path to your own data file
df = pd.read_csv('data1.csv', header=0)
df[df.label.isnull()]
Out[52]: 
                               id content  label
900   Daewoo_DWD_M1051__Opinio...       5    NaN
1463  Indesit_IWC_5105_B_it__O...       1    NaN

# drop the two rows whose label is NaN
df = df[df.label.notnull()]
X = df.content.values
y = df.label.values
transformer = TfidfVectorizer()
X = transformer.fit_transform(X)
estimator = SVC(kernel='linear', class_weight='auto', probability=True)
estimator.fit(X, y)
estimator.predict(X)
Out[54]: array([ 4.,  4.,  4., ...,  2.,  2.,  3.])
estimator.predict_proba(X)
Out[55]: 
array([[ 0.0252,  0.0228,  0.0744,  0.3427,  0.535 ],
       [ 0.002 ,  0.0122,  0.0604,  0.4961,  0.4292],
       [ 0.0036,  0.0204,  0.1238,  0.5681,  0.2841],
       ..., 
       [ 0.1494,  0.3341,  0.1586,  0.1316,  0.2263],
       [ 0.0175,  0.1984,  0.0915,  0.3406,  0.3519],
       [ 0.049 ,  0.0264,  0.2087,  0.3267,  0.3891]])
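
As an aside, the second traceback in the question (inconsistent numbers of samples: [ 858 2598]) is a separate problem: the model predicted on the full matrix X (2598 samples), while the score was apparently computed against y_test (858 samples). A sketch of a consistent split/fit/score flow, using the same old cross_validation module the question imports (renamed model_selection in later scikit-learn releases); X and y are the cleaned data from the session above:

from sklearn.cross_validation import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = SVC(kernel='linear', class_weight='auto')
clf.fit(X_train, y_train)              # fit on the training split only
y_pred = clf.predict(X_test)           # predict on the matching test split
print(accuracy_score(y_test, y_pred))  # y_test and y_pred now have the same length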
