GAN raises an error for 3D images, saying the channels are incorrect



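I get the error below from the following code. I'm working with a GAN. I trained it successfully to generate grayscale images, and now I've changed it to produce 3D images. I added a third dimension in the generator, and now this error appears. Any ideas?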

Traceback (most recent call last):
  File "/media/user/5EB3-54BF/gan3.py", line 84, in <module>
    generator = make_generator_model()
  File "/media/user/5EB3-54BF/gan3.py", line 63, in make_generator_model
    model.add(layers.Conv2DTranspose(256, (5, 5), strides=(1, 1), padding='same', use_bias=False))
  File "/home/user/.local/lib/python3.10/site-packages/tensorflow/python/trackable/base.py", line 205, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/user/.local/lib/python3.10/site-packages/keras/layers/convolutional/conv2d_transpose.py", line 205, in build
    raise ValueError(
ValueError: Inputs should have rank 4. Received input_shape=(None, 14, 14, 3, 512).

def make_generator_model():
    model = tf.keras.Sequential() #make 14
    model.add(layers.Dense(14*14*3*512, use_bias=False, input_shape=(300,))) #ADD MORE NOISE!!!!!!!
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((14, 14, 3, 512)))
    assert model.output_shape == (None, 14, 14, 3, 512)  # Note: None is the batch size
    model.add(layers.Conv2DTranspose(256, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 3, 256)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 28, 28, 3, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    #additional layer added here
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 56, 56, 3, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 112, 112, 3)
    return model

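Conv2DTranspose is not defined for 5D tensors (i.e., rank 5). It expects rank-4 input of shape (batch, height, width, channels), but your Reshape layer produces (None, 14, 14, 3, 512), which is exactly what the ValueError reports. How about using Conv3DTranspose instead? Note that its kernel size and strides each take three values rather than two.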
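A minimal sketch of what that swap might look like, assuming the extra axis of size 3 is meant to be a spatial depth that stays at 3 all the way through. The kernel size (5, 5, 3), the stride of 1 along the depth axis, the final Reshape, and the `noise_dim` / `base_filters` parameters are my assumptions for illustration, not something from the question:

```python
import tensorflow as tf
from tensorflow.keras import layers


def make_generator_model(noise_dim=300, base_filters=512):
    """Generator sketch using Conv3DTranspose so rank-5 tensors are accepted."""
    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=(noise_dim,)))
    model.add(layers.Dense(14 * 14 * 3 * base_filters, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    # Rank 5 including the batch axis: (batch, height, width, depth, channels)
    model.add(layers.Reshape((14, 14, 3, base_filters)))

    # (filters, strides) per block. The depth stride is 1 (assumption), so the
    # depth axis stays at 3 while height and width grow 14 -> 14 -> 28 -> 56.
    blocks = [
        (base_filters // 2, (1, 1, 1)),
        (base_filters // 4, (2, 2, 1)),
        (base_filters // 8, (2, 2, 1)),
    ]
    for filters, strides in blocks:
        model.add(layers.Conv3DTranspose(filters, (5, 5, 3), strides=strides,
                                         padding='same', use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())

    # Last upsampling step: 56 -> 112 in height and width, one output channel.
    model.add(layers.Conv3DTranspose(1, (5, 5, 3), strides=(2, 2, 1),
                                     padding='same', use_bias=False,
                                     activation='tanh'))
    # (None, 112, 112, 3, 1) -> (None, 112, 112, 3): drop the singleton channel
    # so the output matches the shape asserted in the question.
    model.add(layers.Reshape((112, 112, 3)))
    return model
```

Calling `make_generator_model()` with the defaults reproduces the 512/256/128/64 filter counts from the question. If the 3 is actually meant to be RGB color channels rather than a depth axis, an alternative would be to keep Conv2DTranspose, reshape to (14, 14, 512), and give the last layer 3 filters instead of 1.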
