不能在训练我的模型后保存keras模型



下面是来自keras的代码示例。关于一个没有注意力的视觉变压器,在这里。我想在完成训练后将其保存在tensorflow/keras中,但它会产生错误。模型代码的最后一部分在

下面给出
# Get the total number of steps for training.
total_steps = int((len(x_train) / config.batch_size) * config.epochs)
# Calculate the number of steps for warmup.
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)
# Initialize the warmupcosine schedule.
scheduled_lrs = WarmUpCosine(
lr_start=1e-5, lr_max=1e-3, warmup_steps=warmup_steps, total_steps=total_steps,
)
# Get the optimizer.
optimizer = tfa.optimizers.AdamW(
learning_rate=scheduled_lrs, weight_decay=config.weight_decay
)
# Compile and pretrain the model.
model.compile(
optimizer=optimizer,
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
# Train the model
history = model.fit(
train_ds,
epochs=config.epochs,
validation_data=val_ds,
callbacks=[
keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5, mode="auto",)
],
)
# Evaluate the model with the test dataset.
print("TESTING")
loss, acc_top1, acc_top5 = model.evaluate(test_ds)
print(f"Loss: {loss:0.2f}")
print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%")
print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")

我尝试了以下两种方法来保存我的模型

model.save('/content/drive/MyDrive/VIT-SHIFT') 

history.save('/content/drive/MyDrive/VIT-SHIFT')

但是它说模型和历史没有定义。完整的代码可在这个colab笔记本

正如您在关于相同代码的其他问题中提到的,在这里,您可能需要首先实现call方法。现在,保存和重新加载模型应该很简单,但我遇到了一个关于代码中使用的layers.RandomCrop增强层的问题。

def get_augmentation_model():
"""Build the data augmentation model."""
data_augmentation = keras.Sequential(
[
layers.Resizing(..),
layers.RandomCrop(...), # <--- possible causes
layers.RandomFlip("horizontal"),
layers.Rescaling(1 / 255.0),
]
)
return data_augmentation

调用引用已删除变量的函数。这可能意味着创建了函数局部变量,而没有在程序的其他地方引用。这通常是一个错误;考虑在第一次调用时将变量存储在对象属性中。

虽然我看了源代码,但没有感觉到任何可疑的错误。所以,为了结束这个,我实现了自定义随机裁剪图层,理论上做同样的事情。然而,如果你坚持使用内置的layers.RandomCrop,那么我建议在GitHub上打开一个关于这个问题的票。


要使整个代码端到端训练和保存+重新加载,更改代码如下。首先,在ShiftViTModel类中实现call方法。

class ShiftViTModel(keras.Model):
def __init__(self):
super().__init__(**kwargs)
...
# init params
def get_config(self):
...
return config
def _calculate_loss(self, data, training=False):
...
return total_loss, labels, logits
def train_step(self, inputs):
...
return {m.name: m.result() for m in self.metrics}
def test_step(self, data):
...
return {m.name: m.result() for m in self.metrics}

# implement the call function
def call(self, images):
augmented_images = self.data_augmentation(images)
x = self.patch_projection(augmented_images)
logits = self.global_avg_pool(x)
return logits

接下来,实现一个自定义随机裁剪层,并按如下方式使用它而不是layers.RandomCrop

class CustomRandomCrop(layers.Layer):
def __init__(self, size, **kwargs):
super().__init__(**kwargs)
self.size = size

def call(self, inputs, training=True):
if training:
outputs = tf.map_fn(lambda img: tf.image.random_crop(img,
self.size), inputs)
else:
outputs = tf.image.resize(inputs, self.size[:-1])
return outputs

def get_config(self):
config = super().get_config()
config.update(
{
'size': self.size,
}
)
return config
def get_augmentation_model():
"""Build the data augmentation model."""
data_augmentation = keras.Sequential(
[
layers.Resizing(...),
CustomRandomCrop(...), # custom random crop
layers.RandomFlip("horizontal"),
layers.Rescaling(1 / 255.0),
]
)
return data_augmentation

经过这些修改后,我们现在可以做以下操作而不会出现任何错误。

x,y = next(iter(train_ds))
print(x.shape, y.shape)
model(x.shape) 
# OK
history = model.fit(
x=x, y=y,
epochs=1
) 
# OK
model.evaluate(x, y)
# OK

保存和重新加载工作。

model.save('/content/VIT-SHIFT')
# OK
new_model = tf.keras.models.load_model('/content/VIT-SHIFT')
# OK
np.testing.assert_allclose(
model.predict(x), new_model.predict(x)
)
# OK

这里是完整的Code-in-Colab。请保存文件,有一天我可能会把文件从驱动器上删除。

最后,仅供参考,history.save('...'),您不能在keras中这样做。为了保存tensorflow/keras模型,请参考本文档。history对象将只返回跟踪的指标和训练期间的损失。例如

history = model.fit(
x=x, y=y,
epochs=1
)
history.history
{'accuracy': [0.09765625],
'loss': [6.204378128051758],
'top-5-accuracy': [0.36328125]}

您可以保存上述字典中的训练日志,或者在训练时更好地使用CSVLogger

相关内容

  • 没有找到相关文章

最新更新