I'm trying to load a tf.keras (v1.15.0) model from a checkpoint created with the ModelCheckpoint callback, modify it by removing a few layers and adding new ones, and then continue training it on a new task. I'm using tf.distribute.MirroredStrategy() to run distributed training on 2 GPUs.
strategy = tensorflow.distribute.MirroredStrategy()
with strategy.scope():
    # Load pretrained model from checkpoint
    model = get_model()
    model.load_weights('file_name.hdf5')
    # Chop off some layers, add new layers
    model = modify_pretrained_model(model)
    model.compile(optimizer=opt, loss=loss)
The model loads fine and compiles, and I can run model.summary(), but when I call model.fit() or model.predict() the following error shows up in my Python stack trace:
(0) Failed precondition: Error while reading resource variable compression0_conv0_batchnorm/moving_variance from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/compression0_conv0_batchnorm/moving_variance/N10tensorflow3VarE does not exist.
[[{{node time_distributed_1/model_1/compression0_conv0_batchnorm/FusedBatchNormV3/ReadVariableOp_1}}]]
[[dense_1_1/Sigmoid/_225]]
(1) Failed precondition: Error while reading resource variable compression0_conv0_batchnorm/moving_variance from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/compression0_conv0_batchnorm/moving_variance/N10tensorflow3VarE does not exist.
[[{{node time_distributed_1/model_1/compression0_conv0_batchnorm/FusedBatchNormV3/ReadVariableOp_1}}]]
0 successful operations.
1 derived errors ignored
This question seems to address this exact issue, but does not continue training with tf.distribute.
When I instantiate a session outside the distribute scope and set a reference to it inside the scope, the code crashes with the same error:
from tensorflow.keras.backend import set_session

tf_config = some_custom_config
sess = tf.Session(config=tf_config)
graph = tf.get_default_graph()

strategy = tensorflow.distribute.MirroredStrategy()
with strategy.scope():
    set_session(sess)
    # Load pretrained model from checkpoint
    model = get_model()
    model.load_weights('file_name.hdf5')
    # Chop off some layers, add new layers
    model = modify_pretrained_model(model)
    model.compile(optimizer=opt, loss=loss)
I've spent a solid 2-3 days trying to figure this out. The only thing that actually worked was upgrading to tf 2.0.0, at which point everything magically worked. Alternatively, as a last resort, I was able to train the first model within the same Python execution using the same distribute strategy, add and remove the extra layers, recompile, and continue training, but I was never able to reload a tf.keras ModelCheckpoint with a distribute strategy in tf 1.15.0.
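The in-process workaround mentioned above can be sketched roughly as follows. Note that `get_model` and `modify_pretrained_model` here are hypothetical stand-ins for my own helpers, and the toy architecture is only illustrative; the point is that the weights stay live in the same process, so nothing is ever reloaded from an HDF5 checkpoint under the strategy scope:

```python
import tensorflow as tf

def get_model():
    # Placeholder for the original architecture
    inputs = tf.keras.Input(shape=(8,))
    x = tf.keras.layers.Dense(16, activation='relu', name='base_dense')(inputs)
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    return tf.keras.Model(inputs, outputs)

def modify_pretrained_model(model):
    # Placeholder: chop off the head, attach a new output layer
    x = model.get_layer('base_dense').output
    new_out = tf.keras.layers.Dense(1, activation='sigmoid', name='new_head')(x)
    return tf.keras.Model(model.input, new_out)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Train the first model in this same Python execution
    model = get_model()
    model.compile(optimizer='adam', loss='binary_crossentropy')
    # first task: model.fit(x1, y1, ...)

    # Modify in place -- the pretrained weights are still live,
    # so no checkpoint reload is needed under the strategy scope
    model = modify_pretrained_model(model)
    model.compile(optimizer='adam', loss='binary_crossentropy')
    # new task: model.fit(x2, y2, ...)
```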