How to make a confusion matrix to evaluate a convolutional neural network model in Python



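I have a CNN model for a plant seed classification problem. I have tested the model successfully, and now I am trying to create a confusion matrix for it in Python, but I get a ValueError. How can I fix this?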

Code snippet for the x and y values:

seed = 31
np.random.seed(seed)
# Split off the test set, then split the rest into train and validation (cv) sets
x, x_test, y, y_test = train_test_split(train_tensors, train_targets, test_size=0.3, train_size=0.7)
x_train, x_cv, y_train, y_cv = train_test_split(x, y, test_size=0.15, random_state=seed)
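The two calls above are applied one after the other, so the validation set is 15% of the remaining 70%, not 15% of the whole dataset. A minimal sketch with a hypothetical 1,000-sample dataset (placeholder arrays, not the real seed images) shows the resulting sizes: roughly 59.5% train, 10.5% validation and 30% test.

```python
import numpy as np
from sklearn.model_selection import train_test_split

seed = 31
n = 1000  # hypothetical dataset size

# Placeholder data: 1,000 feature rows and one-hot labels for 3 hypothetical classes
train_tensors = np.zeros((n, 4))
train_targets = np.eye(3)[np.arange(n) % 3]

# First split: 30% test, 70% remaining
x, x_test, y, y_test = train_test_split(
    train_tensors, train_targets, test_size=0.3, random_state=seed)
# Second split: 15% of the remaining 70% becomes the validation set
x_train, x_cv, y_train, y_cv = train_test_split(
    x, y, test_size=0.15, random_state=seed)

print(len(x_train), len(x_cv), len(x_test))  # 595 105 300
```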

print(x_train.shape)
print(x_cv.shape)
print(y_train.shape)
print(y_cv.shape)
def one_hot_to_dense(labels_one_hot):
    # Convert one-hot rows, e.g. [0, 0, 1], to integer class indices, e.g. 2
    num_labels = labels_one_hot.shape[0]
    num_classes = labels_one_hot.shape[1]
    labels_dense = np.where(labels_one_hot == 1)[1]
    return labels_dense
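For a well-formed one-hot label matrix (exactly one 1 per row), `one_hot_to_dense` returns the column index of that 1 in each row, which is the same as `np.argmax(labels, axis=1)`. A small sketch with hypothetical labels for 4 samples and 3 classes:

```python
import numpy as np

def one_hot_to_dense(labels_one_hot):
    # Column index of the 1 in each row (same logic as in the question)
    return np.where(labels_one_hot == 1)[1]

# Hypothetical one-hot labels: 4 samples, 3 classes
y_one_hot = np.array([[1, 0, 0],
                      [0, 0, 1],
                      [0, 1, 0],
                      [0, 0, 1]])

dense_where = one_hot_to_dense(y_one_hot)
dense_argmax = np.argmax(y_one_hot, axis=1)
print(dense_where)   # [0 2 1 2]
print(dense_argmax)  # [0 2 1 2]
```

Note that `np.where` returns one entry per element equal to 1, so the output only has one entry per sample when the input really is one-hot; `np.argmax` always returns exactly one index per row.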

Code snippet where I get the error:

validation_predictions = model.predict_classes(x_test)
report=classification_report(one_hot_to_dense(y_test),validation_predictions)
conf_mat= confusion_matrix(one_hot_to_dense(x_test), validation_predictions)
fig, ax = plt.subplots(1,figsize=(10,10))
ax = sns.heatmap(conf_mat, ax=ax, cmap=plt.cm.BuGn, annot=True)
ax.set_xticklabels(abbreviation)
ax.set_yticklabels(abbreviation)
plt.title('Confusion Matrix')
plt.ylabel('True class')
plt.xlabel('Predicted class')
fig.savefig('Confusion matrix.png', dpi=300)
plt.show();
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-21-9c8b8af74165> in <module>
----> 1 conf_mat= confusion_matrix(one_hot_to_dense(x_test), validation_predictions)
2 fig, ax = plt.subplots(1,figsize=(10,10))
3 
4 ax = sns.heatmap(conf_mat, ax=ax, cmap=plt.cm.BuGn, annot=True)
5 ax.set_xticklabels(abbreviation)
~\Anaconda3\lib\site-packages\sklearn\metrics\classification.py in confusion_matrix(y_true, y_pred, labels, sample_weight)
251 
252     """
--> 253     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
254     if y_type not in ("binary", "multiclass"):
255         raise ValueError("%s is not supported" % y_type)
~\Anaconda3\lib\site-packages\sklearn\metrics\classification.py in _check_targets(y_true, y_pred)
69     y_pred : array or indicator matrix
70     """
---> 71     check_consistent_length(y_true, y_pred)
72     type_true = type_of_target(y_true)
73     type_pred = type_of_target(y_pred)
~\Anaconda3\lib\site-packages\sklearn\utils\validation.py in check_consistent_length(*arrays)
203     if len(uniques) > 1:
204         raise ValueError("Found input variables with inconsistent numbers of"
--> 205                          " samples: %r" % [int(l) for l in lengths])
206 
207 
ValueError: Found input variables with inconsistent numbers of samples: [1, 658]

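Your two arrays, one_hot_to_dense(x_test) and validation_predictions, do not have the same length. confusion_matrix needs y_true and y_pred to be two arrays of the same length, one entry per sample (see the documentation). Here you are decoding x_test (the images) instead of the labels; the classification_report line above already uses y_test, so the call should be confusion_matrix(one_hot_to_dense(y_test), validation_predictions).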
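A minimal sketch with small synthetic arrays reproduces both the error and the fix; all names and shapes here are hypothetical stand-ins for the real data. Because x_test holds pixel values, `np.where(x_test == 1)[1]` returns one entry per pixel that happens to equal exactly 1, so its length has nothing to do with the number of samples (one plausible way to end up with the length 1 in the traceback). Decoding y_test gives one label per sample. The sketch also takes `np.argmax` of the predicted probabilities, the usual replacement for `predict_classes` in Keras versions where that method no longer exists, assuming a softmax output layer.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

def one_hot_to_dense(labels_one_hot):
    return np.where(labels_one_hot == 1)[1]

rng = np.random.default_rng(0)
n_classes = 3

# Hypothetical image batch: 6 samples of 8x8x3 pixels with values in [0, 1)
x_test = rng.random((6, 8, 8, 3))
x_test[0, 2, 3, 1] = 1.0  # exactly one pixel equals 1

# Hypothetical one-hot labels and softmax-like predicted probabilities
y_test = np.eye(n_classes)[[0, 1, 2, 2, 1, 0]]
probs = np.eye(n_classes)[[0, 1, 2, 1, 1, 0]]
y_pred = np.argmax(probs, axis=1)  # stand-in for predict_classes

# Wrong: decoding the images gives one entry per pixel equal to 1
wrong_true = one_hot_to_dense(x_test)
print(len(wrong_true), len(y_pred))  # 1 6
raised = False
try:
    confusion_matrix(wrong_true, y_pred)
except ValueError as err:
    raised = True
    print(err)  # inconsistent numbers of samples

# Right: decode the one-hot labels, one entry per sample
y_true = one_hot_to_dense(y_test)
cm = confusion_matrix(y_true, y_pred)
print(cm)
# [[2 0 0]
#  [0 2 0]
#  [0 1 1]]
```

The resulting cm can be passed to sns.heatmap exactly as in the question.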
