加载预先训练的 keras 模型,以便在 Google Cloud 上继续训练



我正在尝试加载一个预先训练的Keras模型,以便在谷歌云上继续训练。它在本地工作,只需加载鉴别器和生成器

model = load_model('myPretrainedModel.h5')

但显然这在谷歌云上不起作用,我尝试使用与从谷歌存储桶读取训练数据相同的方法,包括:

fil = "gs://mygcbucket/myPretrainedModel.h5"    
f = BytesIO(file_io.read_file_to_string(fil, binary_mode=True))
return np.load(f)

但是,这似乎不适用于加载模型,我在运行作业时收到以下错误。

值错误:当 allow_pickle=False 时,无法加载包含酸洗数据的文件

添加allow_pickle=True,会引发另一个错误:

OSError:无法解释文件<_io。作为泡菜0x7fdf2bb42620>的字节IO对象

然后,我尝试了我发现的东西,因为有人建议解决类似的问题,因为我知道它暂时从存储桶中在本地重新保存模型(与作业运行的位置有关(,然后加载它,如下所示:

fil = "gs://mygcbucket/myPretrainedModel.h5"  
model_file = file_io.FileIO(fil, mode='rb')
file_stream = file_io.FileIO(model_file, mode='r')
temp_model_location = './temp_model.h5'
temp_model_file = open(temp_model_location, 'wb')
temp_model_file.write(file_stream.read())
temp_model_file.close()
file_stream.close()
model = load_model(temp_model_location)
return model

但是,这会引发以下错误:

类型错误:预期的二进制或 unicode 字符串,已tensorflow.python.lib.io.file_io。文件IO对象

我必须承认,我不确定我需要做什么才能从我的存储桶中实际加载预先训练的 keras 模型,以及在我在 google cloud 的培训工作中的使用。任何帮助都深表感谢。

我建议使用AI平台笔记本来做到这一点。使用此方法下载训练的模型。检查"代码示例"选项卡下的 Python 代码。将模型置于运行笔记本的 VM 中后,可以像在本地一样加载它。这里有一个使用tf.keras.models.load_model的示例。

最新更新