从MNIST数据集中选择10张图像



我必须从mnist数据集中选择一批10张图像。每个图像应属于一个不同的类别,即图像0至类别0、图像1至类别1等。

我知道,通过以下方式,我提取了所有的数据集。我想知道如何创建一个由10张图像组成的数据集,每张图像都属于不同的类别

(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()

如果您只需要10个,一个简单的方法是获取前10个不重复的条目:

import pandas as pd
import tensorflow as tf
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
ix = ~pd.Series(Y_train).duplicated()
X_10 = X_train[ix]
Y_10 = Y_train[ix]
Y_10
array([5, 0, 4, 1, 9, 2, 3, 6, 7, 8], dtype=uint8)

我使用了唯一方法,从中对数组的唯一元素及其索引进行排序

(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
train_filter = np.unique(Y_train, return_index=True)
X_train, Y_train = X_train[train_filter[1:]], Y_train[train_filter[1:]]