如何在嵌套作用域中使用 tf.get_collection() 作用域筛选



我正在尝试通过在作用域内定义它们并使用tf.get_collection()中的范围过滤来检索一组变量:

with tf.variable_scope('inner'):
    v = tf.get_variable(name='foo', shape=[1])
    ...
    # more variables
    ...
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'inner')
# do stuff with variables

这通常工作正常,但有时我的代码是由已经定义自己的范围的模块调用的,get_collection()不再找到变量:

with tf.variable_scope('outer'):
    with tf.variable_scope('inner'):
        v = tf.get_variable(name='foo', shape=[1])
        ...
        # more variables
        ...

我相信过滤是一个正则表达式,因为我可以通过在我的范围搜索词前面加上 .* 来使 get_collection(( 工作,但这有点黑客。有没有更好的方法来解决这个问题?

当我想训练模型时,我使用 get_collection((,但在此之前,我必须恢复模型数据,如以下代码所示:

with tf.Session() as sess:
    last_check = tf.train.latest_checkpoint(tf_data)
    saver = tf.train.import_meta_graph(last_check + '.meta')
    print (last_check +'.meta')
    saver.restore(sess, last_check)
    ######
    Model_variables = tf.GraphKeys.MODEL_VARIABLES
    Global_Variables = tf.GraphKeys.GLOBAL_VARIABLES
    ######
    all_vars = tf.get_collection(Model_variables)
    # print (all_vars)
    pesos=[]
    for i in all_vars:
        print (str(i) + '  -->  '+ str(i.eval()))

tf.get_collcetion(key, scope="outer/inner"(

最新更新