我正在尝试通过在作用域内定义它们并使用tf.get_collection()
中的范围过滤来检索一组变量:
with tf.variable_scope('inner'):
v = tf.get_variable(name='foo', shape=[1])
...
# more variables
...
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'inner')
# do stuff with variables
这通常工作正常,但有时我的代码是由已经定义自己的范围的模块调用的,get_collection()
不再找到变量:
with tf.variable_scope('outer'):
with tf.variable_scope('inner'):
v = tf.get_variable(name='foo', shape=[1])
...
# more variables
...
我相信过滤是一个正则表达式,因为我可以通过在我的范围搜索词前面加上 .*
来使 get_collection(( 工作,但这有点黑客。有没有更好的方法来解决这个问题?
当我想训练模型时,我使用 get_collection((,但在此之前,我必须恢复模型数据,如以下代码所示:
with tf.Session() as sess:
last_check = tf.train.latest_checkpoint(tf_data)
saver = tf.train.import_meta_graph(last_check + '.meta')
print (last_check +'.meta')
saver.restore(sess, last_check)
######
Model_variables = tf.GraphKeys.MODEL_VARIABLES
Global_Variables = tf.GraphKeys.GLOBAL_VARIABLES
######
all_vars = tf.get_collection(Model_variables)
# print (all_vars)
pesos=[]
for i in all_vars:
print (str(i) + ' --> '+ str(i.eval()))
tf.get_collcetion(key, scope="outer/inner"(