对不起,如果这是一个简单的问题,或者如果我做错了什么,但我是相当新的keras/tensorflow和python。我正在尝试测试一些基于迁移学习的图像分类模型。为此,我想创建一个函数来构建一个模型,其中我只指定了一些参数,它会自动生成所需的模型。我写了下面的代码:
class modelMaker(tf.keras.Model):
def __init__(self, img_height, img_width, trained='None'):
super(modelMaker, self).__init__()
self.x = tf.keras.Input(shape=(img_height, img_width, 3),name="input_layer")
if (trained == 'None'):
pass
elif (trained == 'ResNet50'):
self.x = tf.keras.applications.resnet50.preprocess_input(self.x)
IMG_SHAPE = (img_height,img_width) + (3,)
base_model = tf.keras.applications.ResNet50(input_shape=IMG_SHAPE,
include_top=False,
weights='imagenet')
base_model.trainable = False
for layer in base_model.layers:
if isinstance(layer, keras.layers.BatchNormalization):
layer.trainable = True
else:
layer.trainable = False
self.x = base_model(self.x)
def call(self, inputs):
return self.x(inputs)
现在我只实现了ResNet50和一个空选项,但我计划添加更多。我尝试使用self.x = LAYER(self.x)
添加层的原因是因为模型可以根据未来的参数拥有不同数量的层。
然而,当我试图获得模型的摘要时,使用model.summary()
,我得到以下错误:
ValueError:此模型尚未建立。首先通过调用
build()
或调用fit()
和一些数据来构建模型,或者在第一层指定一个input_shape
参数来自动构建。
有可能建立这样的模型吗?谢谢你的帮助
model.summary()
需要一些关于输入形状和模型(层)结构的信息,以便为您打印它们。因此,您应该在某处将此信息提供给model
对象。
如果使用顺序模型或功能API,只需指定input_shape
参数以运行model.summary()
就足够了。如果您没有指定input_shape
,那么您可以调用您的模型或使用model.build
来提供此信息。
但是当你使用子类时(就像你所做的那样),这个类的对象没有关于形状和图层的信息,除非你调用call()
函数(因为你在call
函数中定义了你的图层结构并向它传递了输入)。
有三种方法调用call()
函数:
model.fit()
:在训练时调用- 可能不适合你的需要,因为你必须先训练你的模型。
model.build()
:内部调用- 只需传递输入的形状,如
model.build((1,128,128,3))
- 只需传递输入的形状,如
model()
:直接调用- 你需要通过至少一个样本(张量),如
model(tf.random.uniform((1,128,128,3))
- 你需要通过至少一个样本(张量),如
修改后的代码应该像这样:
class modelMaker(tf.keras.Model):
def __init__(self, img_height, img_width, num_classes=1, trained='dense'):
super(modelMaker, self).__init__()
self.trained = trained
self.IMG_SHAPE = (img_height,img_width) + (3,)
# define common layers
self.flat = tf.keras.layers.Flatten(name="flatten")
self.classify = tf.keras.layers.Dense(num_classes, name="classify")
# define layers for when "trained" != "resnet"
if self.trained == "dense":
self.dense = tf.keras.layers.Dense(128, name="dense128")
# layers for when "trained" == "resnet"
else:
self.pre_resnet = tf.keras.applications.resnet50.preprocess_input
self.base_model = tf.keras.applications.ResNet50(input_shape=self.IMG_SHAPE, include_top=False, weights='imagenet')
self.base_model.trainable = False
for layer in self.base_model.layers:
if isinstance(layer, tf.keras.layers.BatchNormalization):
layer.trainable = True
else:
layer.trainable = False
def call(self, inputs):
# define your model without resnet
if self.trained == "dense":
x = self.flat(inputs)
x = self.dense(x)
x = self.classify(x)
return x
# define your model with resnet
else:
x = self.pre_resnet(inputs)
x = self.base_model(x)
x = self.flat(x)
x = self.classify(x)
return x
# add this function to get correct output for model summary
def summary(self):
x = tf.keras.Input(shape=self.IMG_SHAPE, name="input_layer")
model = tf.keras.Model(inputs=[x], outputs=self.call(x))
return model.summary()
model = modelMaker(128, 128, trained="resnet") # create object
model.build((10,128,128,3)) # build model
model.summary() # print summary
输出是:
Model: "model_9"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_layer (InputLayer) [(None, 128, 128, 3)] 0
_________________________________________________________________
tf.__operators__.getitem_6 ( (None, 128, 128, 3) 0
_________________________________________________________________
tf.nn.bias_add_6 (TFOpLambda (None, 128, 128, 3) 0
_________________________________________________________________
resnet50 (Functional) (None, 4, 4, 2048) 23587712
_________________________________________________________________
flatten (Flatten) (None, 32768) 0
_________________________________________________________________
classify (Dense) (None, 1) 32769
=================================================================
Total params: 23,620,481
Trainable params: 32,769
Non-trainable params: 23,587,712
_________________________________________________________________