如何创建自定义模型类



对不起,如果这是一个简单的问题,或者如果我做错了什么,但我是相当新的keras/tensorflow和python。我正在尝试测试一些基于迁移学习的图像分类模型。为此,我想创建一个函数来构建一个模型,其中我只指定了一些参数,它会自动生成所需的模型。我写了下面的代码:

class modelMaker(tf.keras.Model):
def __init__(self, img_height, img_width, trained='None'):
super(modelMaker, self).__init__()
self.x = tf.keras.Input(shape=(img_height, img_width, 3),name="input_layer")
if (trained == 'None'):
pass
elif (trained == 'ResNet50'):
self.x = tf.keras.applications.resnet50.preprocess_input(self.x)
IMG_SHAPE = (img_height,img_width) + (3,)
base_model = tf.keras.applications.ResNet50(input_shape=IMG_SHAPE,
include_top=False,
weights='imagenet')
base_model.trainable = False
for layer in base_model.layers:
if isinstance(layer, keras.layers.BatchNormalization):
layer.trainable = True
else:
layer.trainable = False
self.x = base_model(self.x)
def call(self, inputs):
return self.x(inputs)

现在我只实现了ResNet50和一个空选项,但我计划添加更多。我尝试使用self.x = LAYER(self.x)添加层的原因是因为模型可以根据未来的参数拥有不同数量的层。

然而,当我试图获得模型的摘要时,使用model.summary(),我得到以下错误:

ValueError:此模型尚未建立。首先通过调用build()或调用fit()和一些数据来构建模型,或者在第一层指定一个input_shape参数来自动构建。

有可能建立这样的模型吗?谢谢你的帮助

model.summary()需要一些关于输入形状和模型(层)结构的信息,以便为您打印它们。因此,您应该在某处将此信息提供给model对象。

如果使用顺序模型或功能API,只需指定input_shape参数以运行model.summary()就足够了。如果您没有指定input_shape,那么您可以调用您的模型或使用model.build来提供此信息。

但是当你使用子类时(就像你所做的那样),这个类的对象没有关于形状和图层的信息,除非你调用call()函数(因为你在call函数中定义了你的图层结构并向它传递了输入)。

有三种方法调用call()函数:

  1. model.fit():在训练时调用
    • 可能不适合你的需要,因为你必须先训练你的模型。
  2. model.build():内部调用
    • 只需传递输入的形状,如model.build((1,128,128,3))
  3. model():直接调用
    • 你需要通过至少一个样本(张量),如model(tf.random.uniform((1,128,128,3))

修改后的代码应该像这样:

class modelMaker(tf.keras.Model):
def __init__(self, img_height, img_width, num_classes=1, trained='dense'):
super(modelMaker, self).__init__()
self.trained = trained
self.IMG_SHAPE = (img_height,img_width) + (3,)
# define common layers
self.flat = tf.keras.layers.Flatten(name="flatten")
self.classify = tf.keras.layers.Dense(num_classes, name="classify")
# define layers for when "trained" != "resnet"
if self.trained == "dense":
self.dense = tf.keras.layers.Dense(128, name="dense128") 

# layers for when "trained" == "resnet"
else:
self.pre_resnet = tf.keras.applications.resnet50.preprocess_input
self.base_model = tf.keras.applications.ResNet50(input_shape=self.IMG_SHAPE, include_top=False, weights='imagenet')
self.base_model.trainable = False
for layer in self.base_model.layers:
if isinstance(layer, tf.keras.layers.BatchNormalization):
layer.trainable = True
else:
layer.trainable = False

def call(self, inputs):
# define your model without resnet 
if self.trained == "dense":
x = self.flat(inputs)
x = self.dense(x)
x = self.classify(x)
return x
# define your model with resnet
else:
x = self.pre_resnet(inputs)
x = self.base_model(x)
x = self.flat(x)
x = self.classify(x)
return x

# add this function to get correct output for model summary
def summary(self):
x = tf.keras.Input(shape=self.IMG_SHAPE, name="input_layer")
model = tf.keras.Model(inputs=[x], outputs=self.call(x))
return model.summary()

model = modelMaker(128, 128, trained="resnet") # create object
model.build((10,128,128,3))                    # build model
model.summary()                                # print summary

输出是:

Model: "model_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_layer (InputLayer)           [(None, 128, 128, 3)]     0         
_________________________________________________________________
tf.__operators__.getitem_6 ( (None, 128, 128, 3)       0         
_________________________________________________________________
tf.nn.bias_add_6 (TFOpLambda (None, 128, 128, 3)       0         
_________________________________________________________________
resnet50 (Functional)        (None, 4, 4, 2048)        23587712  
_________________________________________________________________
flatten (Flatten)            (None, 32768)             0         
_________________________________________________________________
classify (Dense)             (None, 1)                 32769     
=================================================================
Total params: 23,620,481
Trainable params: 32,769
Non-trainable params: 23,587,712
_________________________________________________________________

相关内容

  • 没有找到相关文章

最新更新