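I have a CNN model for a plant seed classification problem. I tested the model successfully, and now I'm trying to create a confusion matrix for it in Python, but I get a ValueError. How can I fix this?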
Code snippet for the x and y values:
import numpy as np
from sklearn.model_selection import train_test_split

seed = 31
np.random.seed(seed)
# Split off a 30% test set, then split the remainder into train and validation sets
x, x_test, y, y_test = train_test_split(train_tensors, train_targets, test_size=0.3, train_size=0.7)
x_train, x_cv, y_train, y_cv = train_test_split(x,
                                                y,
                                                test_size=0.15,
                                                random_state=seed
                                                )
print(x_train.shape)
print(x_cv.shape)
print(y_train.shape)
print(y_cv.shape)
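The two-stage split above can be checked on stand-in data. This is a minimal sketch, assuming one-hot labels; `train_tensors` and `train_targets` are replaced by small random arrays (hypothetical, not the real seed images), `random_state` is passed to both splits, and `stratify` is an addition not in the original code that keeps the class proportions similar across splits. The loop confirms that images and labels stay the same length in every split.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical stand-in data: 200 tiny "images" and one-hot labels for 4 classes.
rng = np.random.default_rng(31)
n_samples, n_classes = 200, 4
train_tensors = rng.random((n_samples, 8, 8, 3))
class_ids = np.arange(n_samples) % n_classes
train_targets = np.eye(n_classes)[class_ids]  # one-hot rows, shape (200, 4)

seed = 31
# First split: hold out 30% of the samples as the test set.
x, x_test, y, y_test = train_test_split(
    train_tensors, train_targets, test_size=0.3,
    random_state=seed, stratify=class_ids)
# Second split: carve 15% of the remainder out as the validation set.
x_train, x_cv, y_train, y_cv = train_test_split(
    x, y, test_size=0.15, random_state=seed,
    stratify=y.argmax(axis=1))

splits = {"train": (x_train, y_train),
          "cv": (x_cv, y_cv),
          "test": (x_test, y_test)}
for name, (xs, ys) in splits.items():
    # Images and labels must have the same number of samples in every split.
    assert len(xs) == len(ys), name
    print(name, xs.shape, ys.shape)
```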
def one_hot_to_dense(labels_one_hot):
    num_labels = labels_one_hot.shape[0]
    num_classes = labels_one_hot.shape[1]
    # Column index of each 1 in the one-hot matrix, i.e. the class id per sample
    labels_dense = np.where(labels_one_hot == 1)[1]
    return labels_dense
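For a strict one-hot matrix, `one_hot_to_dense` gives the same result as `np.argmax(labels, axis=1)`. The difference matters when the input is not one-hot: `np.where(... == 1)[1]` returns one entry per element that equals exactly 1, not one entry per row, so the output length can differ from the number of samples, while `argmax` always returns one value per row. A small sketch (the helper is re-declared in shortened form so the block runs on its own; `y_demo` and `probs` are made-up arrays):

```python
import numpy as np

def one_hot_to_dense(labels_one_hot):
    # Column index of each 1 in the matrix, i.e. the class id per sample.
    return np.where(labels_one_hot == 1)[1]

y_demo = np.array([[0, 0, 1],
                   [1, 0, 0],
                   [0, 1, 0]])
dense_where = one_hot_to_dense(y_demo)
dense_argmax = np.argmax(y_demo, axis=1)
print(dense_where)   # [2 0 1]
print(dense_argmax)  # [2 0 1]

# Rows without an exact 1 (e.g. probabilities) are silently dropped by np.where.
probs = np.array([[0.2, 0.8], [0.5, 0.5]])
print(len(one_hot_to_dense(probs)), len(np.argmax(probs, axis=1)))  # 0 2
```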
Code snippet where I get the error:
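import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

validation_predictions = model.predict_classes(x_test)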
report=classification_report(one_hot_to_dense(y_test),validation_predictions)
conf_mat= confusion_matrix(one_hot_to_dense(x_test), validation_predictions)
fig, ax = plt.subplots(1,figsize=(10,10))
ax = sns.heatmap(conf_mat, ax=ax, cmap=plt.cm.BuGn, annot=True)
ax.set_xticklabels(abbreviation)
ax.set_yticklabels(abbreviation)
plt.title('Confusion Matrix')
plt.ylabel('True class')
plt.xlabel('Predicted class')
fig.savefig('Confusion matrix.png', dpi=300)
plt.show();
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-21-9c8b8af74165> in <module>
----> 1 conf_mat= confusion_matrix(one_hot_to_dense(x_test), validation_predictions)
      2 fig, ax = plt.subplots(1,figsize=(10,10))
      3
      4 ax = sns.heatmap(conf_mat, ax=ax, cmap=plt.cm.BuGn, annot=True)
      5 ax.set_xticklabels(abbreviation)

~\Anaconda3\lib\site-packages\sklearn\metrics\classification.py in confusion_matrix(y_true, y_pred, labels, sample_weight)
    251
    252     """
--> 253     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
    254     if y_type not in ("binary", "multiclass"):
    255         raise ValueError("%s is not supported" % y_type)

~\Anaconda3\lib\site-packages\sklearn\metrics\classification.py in _check_targets(y_true, y_pred)
     69     y_pred : array or indicator matrix
     70     """
---> 71     check_consistent_length(y_true, y_pred)
     72     type_true = type_of_target(y_true)
     73     type_pred = type_of_target(y_pred)

~\Anaconda3\lib\site-packages\sklearn\utils\validation.py in check_consistent_length(*arrays)
    203     if len(uniques) > 1:
    204         raise ValueError("Found input variables with inconsistent numbers of"
--> 205                          " samples: %r" % [int(l) for l in lengths])
    206
    207

ValueError: Found input variables with inconsistent numbers of samples: [1, 658]
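Your two arrays, one_hot_to_dense(x_test) and validation_predictions, do not have the same length (1 vs. 658 in the traceback). confusion_matrix needs y_true and y_pred to be two arrays of the same length to work (see the documentation). The first argument is built from x_test, the images, instead of y_test, the labels; use one_hot_to_dense(y_test), as you already do in the classification_report line.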
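A minimal sketch of the corrected call, with made-up `y_test` and softmax outputs standing in for the real data and model (both hypothetical). Both arrays passed to `confusion_matrix` come from per-sample labels, so their lengths match. The predictions are taken with `np.argmax(..., axis=1)` on the probability array, which is a common replacement for `model.predict_classes`, since that method is no longer available in recent TensorFlow releases; with a real model the probabilities would come from `model.predict(x_test)`. The final `try` block reproduces the question's error by passing arrays of different lengths.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical stand-ins: one-hot test labels and the softmax output that
# a 3-class model's predict(x_test) might return.
y_test = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1],
                   [0, 1, 0]])
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])

y_true = np.argmax(y_test, axis=1)  # built from the labels, not from x_test
y_pred = np.argmax(probs, axis=1)   # one predicted class per test sample

# Rows are true classes, columns are predicted classes.
conf_mat = confusion_matrix(y_true, y_pred)
print(conf_mat)

# Arrays of different lengths reproduce the ValueError from the question.
try:
    confusion_matrix(y_true[:1], y_pred)
except ValueError as err:
    print(err)
```

The resulting `conf_mat` can be passed unchanged to the `sns.heatmap` plotting code from the question.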