如何从可训练的张量流变量转换为不可训练的张量流变量

  • 本文关键字:张量流 变量 转换 python tensorflow
  • 更新时间 :
  • 英文 :


我需要恢复一个DNN(VGG16Net(,并使用迁移学习来构建另一个网络。所以在这里我需要将一些过滤器、偏置变量从可训练的张量流变量转换为不可训练的变量(我使用的是原生的 tensorflow 框架,而不是 keras 或任何更高的杠杆包(。

例如,从卷积层 4_1 获取权重我用了 conv4_3_filter=sess.graph.get_tensor_by_name('conv4_3/filter:0') 但是变量"conv4_3_filter"始终是可训练的变量。所以,在这里,我试图找到一种通用的方法,将任何张量流变量从可训练转换为不可训练。我该如何解决这个问题?

我认为不可能

修改tf.Variable trainable属性。但是,有多种解决方法。

假设您有两个变量:

import tensorflow as tf
v1 = tf.Variable(tf.random_normal([2, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 2]), name='v2')

当您使用类及其子类进行优化tf.train.Optimizer默认情况下,它会从集合中获取变量tf.GraphKeys.TRAINABLE_VARIABLES。默认情况下,使用 trainable=True 定义的每个变量都会添加到此集合中。您可以做的是清除此集合,并仅将那些您愿意优化的变量附加到其中。例如,如果我只想优化v1而不优化v2

var_list = tf.trainable_variables()
print(var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>,
#  <tf.Variable 'v2:0' shape=(2, 2) dtype=float32_ref>]
tf.get_default_graph().clear_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
cleared_var_list = tf.trainable_variables()
print(cleared_var_list)
# []
tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, var_list[0])
updated_var_list = tf.trainable_variables()
print(updated_var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>]

另一种方法是使用优化器的关键字参数var_list并传递要在训练期间(在执行train_op期间(更新的变量:

optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss, var_list=[v1])

相关内容

  • 没有找到相关文章

最新更新