我正在尝试使用Conv2D
层对fashion_mnist
数据集进行分类,据我所知,使用以下代码可以轻松完成:
import tensorflow as tf
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0
model = tf.keras.Sequential([
tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1)),
tf.keras.layers.Input(shape=(28,28),batch_size=32),
tf.keras.layers.Conv2D(4,kernel_size=3),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation="softmax")
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(x=train_images, y=train_labels, validation_data=(test_images, test_labels), epochs=10)
然而,我需要不使用Lambda层. 因此,上述解决方案是不正确的。
所以,我想知道,我如何在不使用Lambda层的情况下对mnist_fashion
数据集进行分类?
更新:当我使用下面的代码添加a维度时:
train_images = train_images / 255.0
train_images = tf.expand_dims(train_images,axis=0)
test_images = test_images / 255.0
test_images = tf.expand_dims(test_images,axis=0)
并对相同的模型运行,我得到以下错误:
ValueError: Data cardinality is ambiguous:
x sizes: 1
y sizes: 60000
Make sure all arrays contain the same number of samples.
有几个选项:直接在train_images
上使用expand_dims
并改变输入形状或使用Reshape
层代替Lambda
层或完全删除该层并改变input_shapetf.keras.layers.Input(shape=(28,28, 1),batch_size=32)
。取决于你想要什么。下面是跨axis=-1
的expand_dims
选项:
train_images = train_images / 255.0
train_images = tf.expand_dims(train_images, axis=-1)
test_images = test_images / 255.0
test_images = tf.expand_dims(test_images, axis=-1)
将Input
图层改为tf.keras.layers.Input(shape=(28, 28, 1)
Q1答:因为Conv2D
层需要3D输入(不包括批量大小),所以类似:(rows, cols, channels)
。你的数据形状为(samples, 28, 28)
。如果您的通道出现在行和颜色之前,则可以在axis=1
上使用expand_dims
,从而在使用axis=-1
时使用(samples, 1, 28, 28)
而不是(samples, 28, 28, 1)
。如果是前者,则必须将Conv2D
层的data_format
参数设置为channels_first
。使用axis=0
会得到形状(1 samples, 28, 28)
,这是不正确的,因为第一个维度应该保留为批处理维度。
Q2答:我用shape=(28, 28, 1)
,因为时尚主义的图像是灰度图像。也就是说,它们有一个通道(我们已经明确定义了)。另一方面,RGB图像有3个通道:红色通道,绿色通道和蓝色通道。