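I am doing multi-class classification on imbalanced classes. I am comparing the results of SGDClassifier(), GradientBoostingClassifier(), RandomForestClassifier(), and LogisticRegression() with class_weight='balanced', and I need to compute the accuracy. I tried the following approach to compute a weighted accuracy: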
n_samples = len(y_train)
weights_cof = float(n_samples)/(n_classes*np.bincount(data[target_label].as_matrix().astype(int))[1:])
sample_weights = np.ones((n_samples,n_classes)) * weights_cof
print accuracy_score(y_test, y_pred, sample_weight=sample_weights)
y_train is a binary array, so sample_weights has the same shape as y_train, (n_samples, n_classes). When I run the script, I get the following error:
Update:
Traceback (most recent call last):
  File "C:\Program Files (x86)\JetBrains\PyCharm Community Edition 2016.3.2\helpers\pydev\pydevd.py", line 1596, in <module>
    globals = debugger.run(setup['file'], None, None, is_module)
  File "C:\Program Files (x86)\JetBrains\PyCharm Community Edition 2016.3.2\helpers\pydev\pydevd.py", line 974, in run
    pydev_imports.execfile(file, globals, locals) # execute the script
  File "D:/Destiny/DestinyScripts/MainLocationAware.py", line 424, in <module>
    predict_country(featuresDF, score, featuresLabel, country_sample_size, 'gbc')
  File "D:/Destiny/DestinyScripts/MainLocationAware.py", line 313, in predict_country
    print accuracy_score(y_test, y_pred, sample_weight=sample_weights)
  File "C:\ProgramData\Anaconda2\lib\site-packages\sklearn\metrics\classification.py", line 183, in accuracy_score
    return _weighted_sum(score, sample_weight, normalize)
  File "C:\ProgramData\Anaconda2\lib\site-packages\sklearn\metrics\classification.py", line 108, in _weighted_sum
    return np.average(sample_score, weights=sample_weight)
  File "C:\ProgramData\Anaconda2\lib\site-packages\numpy\lib\function_base.py", line 1124, in average
    "Axis must be specified when shapes of a and weights "
TypeError: Axis must be specified when shapes of a and weights differ.
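The error indicates that sample_weights does not have the same shape as the per-sample scores computed from y_test/y_pred. Basically, accuracy_score builds a boolean array from y_test == y_pred and passes it, together with sample_weights, to np.average. When weights are given, np.average checks that the input array and the weights have the same shape (unless an axis is specified), and in this case they clearly do not.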
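A minimal NumPy sketch of that failure mode, using hypothetical toy arrays (`y_test`, `y_pred`, `weights_2d`, `weights_1d` are illustrative, not the asker's data). It roughly mimics, rather than calls, what accuracy_score does with one-hot input: each row is reduced to a single correct/incorrect flag, so the scores are 1-D. Passing those flags to np.average with 2-D weights shaped like the labels (as in the question) raises the TypeError, while a 1-D weight vector of length n_samples works.

```python
import numpy as np

# Hypothetical toy data: 4 samples, 3 classes, one-hot encoded (as in the question).
y_test = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]])
y_pred = np.array([[1, 0, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0]])

# One flag per sample: a sample is correct only if its whole one-hot row matches.
sample_score = np.all(y_test == y_pred, axis=1)  # shape (4,)

# Weights shaped like the labels, (n_samples, n_classes), as in the question.
weights_2d = np.ones(y_test.shape) * np.array([1.0, 2.0, 3.0])  # shape (4, 3)

try:
    np.average(sample_score, weights=weights_2d)
except TypeError as exc:
    # Shapes (4,) and (4, 3) differ and no axis was given.
    print("TypeError:", exc)

# One weight per sample, shape (n_samples,), is what np.average expects here.
weights_1d = np.array([1.0, 2.0, 3.0, 2.0])
print(np.average(sample_score, weights=weights_1d))  # 0.75
```

Rows 0, 2 and 3 match, so the weighted result is (1 + 3 + 2) / (1 + 2 + 3 + 2) = 0.75.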
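Update
Your comment that "sample_weights, y_test and y_pred all have the same shape (n_samples, n_classes)" exposes the problem. According to the accuracy_score documentation, y_true and y_pred (in your case y_test and y_pred) should be one-dimensional label arrays, and sample_weight should have shape (n_samples,), one weight per sample. (A 2-D label indicator matrix is also accepted, but it is treated as multilabel input and each row is scored as a single exact match, so the per-sample scores are still 1-D.) Are you perhaps using one-hot encoded labels? If so, convert them to single-value labels, build one weight per sample, and try accuracy_score again.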
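A minimal sketch of that fix, assuming one-hot labels like those in the question. The toy arrays are hypothetical, the labels produced by `argmax` are 0-based (unlike the question's `bincount(...)[1:]`, which suggests 1-based labels), and the class weights are computed from `y_test` rather than from the full dataset. The weighting formula n_samples / (n_classes * count_per_class) is the one from the question; the difference is that each sample looks up the weight of its own class, giving a 1-D vector. `compute_sample_weight` from scikit-learn is used only as a cross-check.

```python
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.utils.class_weight import compute_sample_weight

# Hypothetical one-hot labels: 6 samples, 3 classes, class 0 dominant.
y_test_onehot = np.array([[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
y_pred_onehot = np.array([[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1]])

# Convert one-hot rows to single integer labels, shape (n_samples,).
y_test = y_test_onehot.argmax(axis=1)
y_pred = y_pred_onehot.argmax(axis=1)

# "Balanced" class weights: n_samples / (n_classes * count_per_class).
# Assumes every class occurs at least once in y_test (otherwise division by zero).
n_samples, n_classes = y_test_onehot.shape
class_weights = n_samples / (n_classes * np.bincount(y_test, minlength=n_classes))

# Index by label so each sample gets its own class's weight: 1-D, length n_samples.
sample_weights = class_weights[y_test]

print(accuracy_score(y_test, y_pred))                                # 5 of 6 correct
print(accuracy_score(y_test, y_pred, sample_weight=sample_weights))  # class-balanced

# scikit-learn's helper produces the same per-sample weights.
assert np.allclose(sample_weights, compute_sample_weight("balanced", y_test))
```

With these class-balanced weights the weighted accuracy equals the mean per-class recall: here (1 + 0 + 1) / 3 = 2/3, compared with 5/6 unweighted, because the single misclassified minority-class sample now counts as much as the four majority-class samples combined.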