混淆矩阵"Can't handle mix of multiclass and unkown"



我的混淆矩阵显示了一个我无法理解的错误。我想要一个混淆矩阵来显示两个数组y_predy_test之间的混淆。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.metrics import accuracy_score
import pylab as pl
# Code that fills up two numpy arrays, y_test and y_pred with integers
print y_test.shape
print y_pred.shape
cm = confusion_matrix(y_test,y_pred)
plt.matshow(cm)
plt.title('Confusion matrix')
plt.colorbar()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()

错误为:

Traceback (most recent call last):
  File "C:work_asaakicodetest_samme_46_classes_unconfused.py", line 159, in <module>
    cm = confusion_matrix(y_test,y_pred)
  File "C:Anacondalibsite-packagessklearnmetricsmetrics.py", line 742, in confusion_matrix
    y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred)
  File "C:Anacondalibsite-packagessklearnmetricsmetrics.py", line 115, in _check_clf_targets
    "".format(type_true, type_pred))
ValueError: Can't handle mix of multiclass and unknown

这个错误是什么意思?当我打印出y_pred.shapey_test.shape时,我得到了相同的形状,(318L)。两个数组的值都在0到29之间。

没关系,我找到了答案,很简单。问题是,在代码(未显示)中,我使用dtype=object将y_pred填充为numpy数组,如下所示:

y_pred = np.array(pickle.load(file("PATH_TO_FILE")), dtype=object)

我去掉了dtype=object部分,它工作得很好。

相关内容

  • 没有找到相关文章

最新更新