Tensorflow:用tf替换/馈送图的占位符.变量?



我有一个模型M1,其数据输入是占位符M1.input,并且训练了权重。 我的目标是构建一个新的模型M2该模型以tf.Variable的形式从输入w计算M1的输出o(及其训练的权重((而不是将实际值提供给M1.input(。换句话说,我将训练好的模型M1用作黑盒函数来构建新的模型o = M1(w)(在我的新模型中,w是要学习的,M1的权重被固定为常量(。问题是M1只接受输入M1.input,我们需要通过它输入实际值,而不是 tf。变量如w.

作为构建M2的天真解决方案,我可以在M2内手动构建M1,然后用预先训练的值初始化M1的权重,并使它们无法在M2内训练。但是,在实践中,M1很复杂,我不想在M2内再次手动构建M1。我正在寻找一个更优雅的解决方案,例如解决方法或直接解决方案,将M1的输入占位符M1.input替换为 tf。可变w.

谢谢你的时间。

这是可能的。怎么样:

import tensorflow as tf

def M1(input, reuse=False):
with tf.variable_scope('model_1', reuse=reuse):
param = tf.get_variable('param', [1])
o = input + param
return o

w = tf.get_variable('some_w', [1])
plhdr = tf.placeholder_with_default(w, [1])
output_m1 = M1(plhdr)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(w.assign([42]))
print(sess.run(output_m1, {plhdr: [0]}))  # direct from placeholder
print(sess.run(output_m1))                # direct from variable

因此,当feed_dict具有占位符的值时,将使用此值。否则,使用变量"w"的回退选项处于活动状态。

相关内容

  • 没有找到相关文章

最新更新