我有一个用深度学习人脸识别的学校项目。我需要倒数矩阵来衡量性能指标,比如准确性,精确度。我为此尝试了以下代码。但是,y_test参数给出了一个错误。我怎么解决这个问题?
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(img_array, img_labels,
shuffle=True, stratify=img_labels,
test_size=0.1, random_state=42)
print('Eğitim için eleman sayısı, yükseklik/genişlik ve kanal sayısı: ', x_train.shape)
print('Test için eleman sayısı, yükseklik/genişlik ve kanal sayısı: : ',x_test.shape)
print('Eğitimdeki örnek ve sınıf sayısı :', y_train.shape)
print('Testteki örnek ve sınıf sayısı : ',y_test.shape)
我的代码cm = confusion_matrix(y_test, y_pred)
print(cm)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [55], in <cell line: 1>()
----> 1 cm = confusion_matrix(y_test, y_pred)
2 print(cm)
File ~anaconda3libsite-packagessklearnmetrics_classification.py:307, in confusion_matrix(y_true, y_pred, labels, sample_weight, normalize)
222 def confusion_matrix(
223 y_true, y_pred, *, labels=None, sample_weight=None, normalize=None
224 ):
225 """Compute confusion matrix to evaluate the accuracy of a classification.
226
227 By definition a confusion matrix :math:`C` is such that :math:`C_{i, j}`
(...)
305 (0, 2, 1, 1)
306 """
--> 307 y_type, y_true, y_pred = _check_targets(y_true, y_pred)
308 if y_type not in ("binary", "multiclass"):
309 raise ValueError("%s is not supported" % y_type)
File ~anaconda3libsite-packagessklearnmetrics_classification.py:93, in _check_targets(y_true, y_pred)
90 y_type = {"multiclass"}
92 if len(y_type) > 1:
---> 93 raise ValueError(
94 "Classification metrics can't handle a mix of {0} and {1} targets".format(
95 type_true, type_pred
96 )
97 )
99 # We can't have more than one value on y_type => The set is no more needed
100 y_type = y_type.pop()
ValueError: Classification metrics can't handle a mix of multilabel-indicator and continuous-multioutput targets
我知道我不应该在答案中提供这个,但我现在无法添加评论。分类报告期望y_pred和y_test都是一个1-D数组,类标签为整数。TensorFlow模型的预测主要是一个二维数组,每个条目都是一维数组,具有给定行的类概率。因此,您需要对y_pred进行一些预处理。我几周前遇到了类似的东西,我将分享几行代码,可能会有所帮助。
res = np.array(res)
res = res.flatten()
res = np.round(res)
请注意,以上代码是用于二进制分类的。对于多标签分类,您可以使用np.argmax
。