我正在尝试在TensorFlow中创建混淆矩阵,但我得到了一个
类型错误:图像数据无法转换为浮点型。
图像被准确预测,但现在我想使用 matplotlib 显示混淆矩阵。我尝试转换为 np.array((,但错误仍然相同。
我正在遵循来自scikit-learn的混淆矩阵的官方文档。https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
def plot_confusion_matrix(cm, classes,
normalize=False,
title='Confusion matrix',
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
print(cm)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
if result[0][0]>0.85:
predictions.append(result[0][0])
elif result[0][1]>0.85:
predictions.append(result[0][1])
elif result[0][2]>0.85:
predictions.append(result[0][2])
elif result[0][3]>0.85:
predictions.append(result[0][3])
elif result[0][4]>0.85:
predictions.append(result[0][4])
elif result[0][5]>0.85:
predictions.append(result[0][5])
class_names = ['Up', 'Down', 'Left', 'Right', 'Forward', 'Backward']
# label_list contains the filename e.g. hand1.jpg, hand2.jpg....
# Compute confusion matrix
cnf_matrix = tf.confusion_matrix(label_list,predictions,num_classes=6)
np.set_printoptions(precision=2)
# Plot non-normalized confusion matrix
plt.figure()
# ERROR HERE
plot_confusion_matrix(cnf_matrix, classes=class_names,title='Confusion matrix, without normalization')
# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,title='Normalized confusion matrix')
plt.show()
我还没有在我的电脑上测试过它。您的描述对我来说有点模棱两可(错误行等(,但是您的代码和您链接的文档的主要区别是confusion_matrix()
.只需尝试使用sckit-learn的confusion_matrix()
而不是confusion_matrix()
的张量流(在链接中,使用前者(。在我看来,这是你能走的最简单的方法。
编辑:做出这样的预测:
for i in range(6):
if result[0][i] > 0.85:
predictions.append(i)
continue
那么你的预测将不是连续的。在这里,您的预测应该是整数,因为您正在预测类标签。