Is there a way to extract the desired classes from the CIFAR-10 training dataset?



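What I want to do seems simple, but it just doesn't work. I want to perform some operations on the images (matrices) of each class, so I first have to extract each class's images from the shuffled batches.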

from tensorflow.keras import datasets
import numpy as np
(train_images, train_labels), (test_images, test_labels)= datasets.cifar10.load_data()
print(len(train_images))
print(len(train_labels))
train_images[train_labels==6]

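Here is the error. It happens because the image array has shape (50000, 32, 32, 3). Even though the images and labels have the same length, 50000, Python somehow can't filter each matrix as a single item. Any help would be very welcome.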

50000
50000

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-170-029cc3d4f0a9> in <module>
5 
6 
----> 7 train_images[train_labels==6]
IndexError: boolean index did not match indexed array along dimension 1; dimension is 32 but corresponding boolean dimension is 1
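The same error can be reproduced on tiny synthetic arrays that mimic the CIFAR-10 layout. The arrays below are made-up stand-ins (6 images of 4x4 RGB pixels, labels stored as an (N, 1) column), not real data. This is a minimal sketch showing that a 2-D mask is checked against the first two axes of the image array, while a flattened 1-D mask selects whole images:

```python
import numpy as np

# Synthetic stand-ins with the same layout as CIFAR-10 (not real data):
# 6 images of 4x4 RGB pixels, labels stored as a column vector of shape (N, 1).
images = np.zeros((6, 4, 4, 3), dtype=np.uint8)
labels = np.array([[0], [6], [2], [6], [6], [1]])

mask_2d = labels == 6            # shape (6, 1): matched against axes 0 AND 1 of images
raised = False
try:
    images[mask_2d]              # axis 1 of images has size 4, mask axis 1 has size 1
except IndexError as err:
    raised = True
    print(err)

mask_1d = (labels == 6).ravel()  # shape (6,): matched against axis 0 only
selected = images[mask_1d]
print(selected.shape)            # (3, 4, 4, 3)
```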

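The problem here is that train_labels has shape (50000, 1), so the mask train_labels == 6 also has shape (50000, 1). NumPy applies a boolean mask to the leading dimensions of the indexed array, so it treats this one as a 2-D mask over the first two axes of train_images, which have sizes (50000, 32). The second axis (32) does not match the mask's second axis (1), hence the IndexError. The fix is simple: flatten the mask to shape (50000,).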

from tensorflow.keras import datasets
import numpy as np
(train_images, train_labels), (test_images, test_labels)= datasets.cifar10.load_data()
print('Images Shape: {}'.format(train_images.shape))
print('Labels Shape: {}'.format(train_labels.shape))
idx = (train_labels == 6).reshape(train_images.shape[0])
print('Index Shape: {}'.format(idx.shape))
filtered_images = train_images[idx]
print('Filtered Images Shape: {}'.format(filtered_images.shape))
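Since the goal is to operate on every class, the same idea can be extended to split the whole dataset at once. The helper below, split_by_class, is a hypothetical name introduced here, not part of Keras or NumPy. It is a sketch that assumes labels come as (N,) or (N, 1) integer arrays, as load_data() returns them. The demo uses synthetic arrays in place of CIFAR-10, with each image tagged by its index so the selection can be checked:

```python
import numpy as np

def split_by_class(images, labels, num_classes=10):
    """Hypothetical helper: map each class id to an array of its images.

    labels may have shape (N,) or (N, 1); ravel() turns either into (N,)
    so the boolean mask only indexes the first axis of images.
    """
    flat = np.asarray(labels).ravel()
    if flat.shape[0] != images.shape[0]:
        raise ValueError("images and labels must have the same length")
    return {c: images[flat == c] for c in range(num_classes)}

# Synthetic stand-in for CIFAR-10 (not real data): 20 images, 2 per class.
images = np.zeros((20, 32, 32, 3), dtype=np.uint8)
images[:, 0, 0, 0] = np.arange(20)            # tag each image with its index
labels = (np.arange(20) % 10).reshape(-1, 1)  # shape (20, 1), like load_data()

by_class = split_by_class(images, labels)
print({c: arr.shape[0] for c, arr in by_class.items()})  # every class has 2 images
```

Boolean indexing returns copies, so holding all ten class arrays roughly doubles the memory used by the training images; for the real dataset, by_class[6] would replace filtered_images from the snippet above.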
