如何在不使用Lambda层的情况下向数据点添加维度?



我正在尝试使用Conv2D层对fashion_mnist数据集进行分类,据我所知,使用以下代码可以轻松完成:

import tensorflow as tf
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0
model = tf.keras.Sequential([
tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1)),
tf.keras.layers.Input(shape=(28,28),batch_size=32),      
tf.keras.layers.Conv2D(4,kernel_size=3),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation="softmax")
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(x=train_images, y=train_labels, validation_data=(test_images, test_labels), epochs=10)

然而,我需要不使用Lambda层. 因此,上述解决方案是不正确的。

所以,我想知道,我如何在不使用Lambda层的情况下对mnist_fashion数据集进行分类?

更新:当我使用下面的代码添加a维度时:

train_images = train_images / 255.0
train_images = tf.expand_dims(train_images,axis=0)
test_images = test_images / 255.0
test_images = tf.expand_dims(test_images,axis=0)

并对相同的模型运行,我得到以下错误:

ValueError: Data cardinality is ambiguous:
x sizes: 1
y sizes: 60000
Make sure all arrays contain the same number of samples.

有几个选项:直接在train_images上使用expand_dims并改变输入形状或使用Reshape层代替Lambda层或完全删除该层并改变input_shapetf.keras.layers.Input(shape=(28,28, 1),batch_size=32)。取决于你想要什么。下面是跨axis=-1expand_dims选项:

train_images = train_images / 255.0
train_images = tf.expand_dims(train_images, axis=-1)
test_images = test_images / 255.0
test_images = tf.expand_dims(test_images, axis=-1)

Input图层改为tf.keras.layers.Input(shape=(28, 28, 1)

Q1答:因为Conv2D层需要3D输入(不包括批量大小),所以类似:(rows, cols, channels)。你的数据形状为(samples, 28, 28)。如果您的通道出现在行和颜色之前,则可以在axis=1上使用expand_dims,从而在使用axis=-1时使用(samples, 1, 28, 28)而不是(samples, 28, 28, 1)。如果是前者,则必须将Conv2D层的data_format参数设置为channels_first。使用axis=0会得到形状(1 samples, 28, 28),这是不正确的,因为第一个维度应该保留为批处理维度。

Q2答:我用shape=(28, 28, 1),因为时尚主义的图像是灰度图像。也就是说,它们有一个通道(我们已经明确定义了)。另一方面,RGB图像有3个通道:红色通道,绿色通道和蓝色通道。

相关内容

  • 没有找到相关文章

最新更新