从tensorflow数据集中获取错误分类的样本



通过读取图像数据时

train = keras.preprocessing.image_dataset_from_directory(
  './data', 
  labels='inferred', 
  label_mode='binary', 
  validation_split=0.2, 
  subset="training", 
  image_size=(img_height, img_width), 
  batch_size=sz_batch, 
  crop_to_aspect_ratio=True
)

它们存储在tensorflow数据集中。我用同样的程序读取验证数据。为了分析我的NN(一个顺序张量流NN(,我想绘制错误分类的样本(图片(。我可以通过轻松获得预测

pred = model.predict(validation)

但是,我如何才能从tensorflow数据集中获得分类错误的样本?

要获得分类错误的样本,您可以使用以下代码,其中"classes"是原始标签,"pred_labels"是预测标签,它们的索引将存储在索引列表中

wrong_pred=[]
indices=[]
for i in range(len(classes)):
  if classes[i]!=pred_labels[i]:
    indices.append(i)
    wrong_pred.append([classes[i],pred_labels[i]])

请在这里找到完整的代码。非常感谢。

相关内容

最新更新