TensorFlow: how to reload from a saved checkpoint and make incremental changes



I have a checkpoint from a pre-trained model A, but not the code that built its graph. Now I want to reload model A and attach a subgraph B to it to get a final model C, i.e. C = A + B. However, since model A is already well trained, I don't want to train it inside C; I only want to train subgraph B. In other words, subgraph A should take part only in the prediction phase (forward pass), not in the training phase (backward pass), while subgraph B takes part in both. The goal is that, with the help of subgraph B, model C will outperform model A.

How can this be done? I guess it has something to do with saver/restore, but I can't figure out how to make all the pieces work together. Any code snippet would be greatly appreciated.

I'm using TensorFlow 1.12.

Well, what you want to do is a fairly typical machine-learning use case, and it can be achieved in several ways.

If all you have is a checkpoint and no source code for the pre-trained model, then after loading the model from the checkpoint you need to:

  • clear the tf.GraphKeys.TRAINABLE_VARIABLES collection, so that the variables restored from the checkpoint cannot be trained by the fine-tuning model
  • clear the tf.GraphKeys.GLOBAL_VARIABLES collection, so that the variables restored from the checkpoint are not re-initialized by the fine-tuning model

Then build and train the fine-tuning model as usual; the core of the trick is sketched below.
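In miniature, freezing the restored part comes down to two calls (a sketch; the full tested script follows):

# Freeze everything restored from the checkpoint: variables created
# after these two calls are the only ones the optimizer will train and
# the only ones tf.global_variables_initializer() will touch.
tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES).clear()
tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES).clear()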

The following code was tested with TensorFlow 1.12:

import os
import shutil
import tensorflow as tf
import numpy as np
from absl import logging, app, flags

flags.DEFINE_string('pre_train_model_checkpoint_path', '/tmp/pre_train_model', '')
FLAGS = flags.FLAGS

def train_and_save_pre_train_model():
    # model: y = w*x + b
    x = tf.placeholder(tf.float32, shape=(None), name='x')
    y_true = tf.placeholder(tf.float32, shape=(None), name='y')
    w = tf.get_variable('w', shape=())
    b = tf.get_variable('b', shape=())
    y = w * x + b
    y_pred = tf.identity(y, 'y_pred')
    loss = tf.losses.mean_squared_error(y_true, y_pred)
    train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    for _ in range(2000):
        x_val = np.random.rand(128)
        y_val = x_val * 2.0 - 1.0  # to make w=2.0 and b=-1.0
        sess.run(train_op, {x: x_val, y_true: y_val})
    saver = tf.train.Saver()
    # start from a clean checkpoint directory (ignore_errors avoids a
    # crash on the first run, when the directory does not exist yet)
    shutil.rmtree(FLAGS.pre_train_model_checkpoint_path, ignore_errors=True)
    os.makedirs(FLAGS.pre_train_model_checkpoint_path)
    save_path = os.path.join(FLAGS.pre_train_model_checkpoint_path, 'ckpt')
    saver.save(sess, save_path, global_step=100)
    return sess.run((w, b))

def main(_):
    w_val, b_val = train_and_save_pre_train_model()
    # check that the pre-trained model is trained as expected
    logging.info('w_val={}, b_val={}'.format(w_val, b_val))
    # load the pre-trained model from its meta graph and checkpoint
    tf.reset_default_graph()
    meta_file = os.path.join(FLAGS.pre_train_model_checkpoint_path, 'ckpt-100.meta')
    saver = tf.train.import_meta_graph(meta_file)
    save_path = os.path.join(FLAGS.pre_train_model_checkpoint_path, 'ckpt-100')
    sess = tf.Session()
    saver.restore(sess, save_path)
    x = tf.get_default_graph().get_tensor_by_name('x:0')
    y_pred = tf.get_default_graph().get_tensor_by_name('y_pred:0')
    w = tf.get_default_graph().get_tensor_by_name('w:0')
    b = tf.get_default_graph().get_tensor_by_name('b:0')
    # make the variables from the ckpt non-trainable by the fine-tuning model
    tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES).clear()
    # make the variables from the ckpt not re-initialized by the fine-tuning model
    tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES).clear()
    # build the fine-tuning model: y2 = w2*y + b2
    w2 = tf.get_variable('w2', shape=())
    b2 = tf.get_variable('b2', shape=())
    y2_pred = w2 * y_pred + b2
    y2_true = tf.placeholder(tf.float32, shape=(None), name='y2')
    loss = tf.losses.mean_squared_error(y2_true, y2_pred)
    train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    # only w2 and b2 are initialized here, because the restored variables
    # were removed from the GLOBAL_VARIABLES collection above
    sess.run(tf.global_variables_initializer())
    for _ in range(2000):
        x_val = np.random.rand(128)
        y2_val = (x_val * w_val + b_val) * 10.0 + 1.0  # to make w2=10.0 and b2=1.0
        sess.run(train_op, {x: x_val, y2_true: y2_val})
    w2_val, b2_val = sess.run((w2, b2))
    logging.info('w2_val={}, b2_val={}'.format(w2_val, b2_val))
    # assert that w and b were not trained
    w_val_after_fine_tuning, b_val_after_fine_tuning = sess.run((w, b))
    logging.info('w_val_after_fine_tuning={}, b_val_after_fine_tuning={}'.format(
        w_val_after_fine_tuning, b_val_after_fine_tuning))
    assert w_val == w_val_after_fine_tuning
    assert b_val == b_val_after_fine_tuning
    logging.info('all good')

if __name__ == '__main__':
    app.run(main)
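As an aside, if you would rather leave the collections untouched, another common way to freeze the restored part is to tell the optimizer explicitly which variables to train. A minimal sketch, reusing loss, w2, b2, and sess from the script above:

# Train only the new variables; the restored w and b are never updated
# because they are not in var_list.
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=[w2, b2])
# Initialize only the newly created variables so the restored values
# survive. (Optimizers that keep slot variables, e.g. Adam or Momentum,
# would need their slots initialized as well.)
sess.run(tf.variables_initializer([w2, b2]))

A third option with the same effect on the gradients is to wrap the restored output in tf.stop_gradient(y_pred), which blocks backpropagation into subgraph A entirely.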
