Storing RNN states using graph collections

I often use tf.add_to_collection to have TensorFlow automatically serialize intermediary results into a checkpoint. I find this the most convenient way to later fetch pointers to the interesting tensors after a model has been restored from a checkpoint. However, I realized that RNN state tuples cannot easily be added to a graph collection. Consider the following dummy example in TF 1.3:

import tensorflow as tf
import numpy as np

# dummy input: batches of sequences of length 5 with a single feature
in_ = tf.placeholder(tf.float32, shape=[None, 5, 1])
batch_size = tf.shape(in_)[0]
# two stacked LSTM layers
cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
cell2 = tf.nn.rnn_cell.BasicLSTMCell(num_units=256)
cell = tf.nn.rnn_cell.MultiRNNCell([cell1, cell2])
outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
                                        inputs=in_,
                                        initial_state=cell.zero_state(batch_size, dtype=tf.float32))
# last_state is a nested tuple of LSTMStateTuples; serializing this collection
# produces the warning below
tf.add_to_collection('states', last_state)
loss = tf.reduce_mean(in_ - outputs)
loss_s = tf.summary.scalar('loss', loss)
writer = tf.summary.FileWriter('.', tf.get_default_graph())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    l, s = sess.run([loss, loss_s], feed_dict={in_: np.ones([1, 5, 1])})
    writer.add_summary(s)

This produces the following warning:

WARNING:tensorflow:Error encountered when serializing states.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'tuple' object has no attribute 'name'

It seems that the serialization cannot handle tuples, and of course the last_state variable is a tuple. Perhaps one could loop through the tuple and add each element to the collection separately, but that seems overly complicated. Is there a better way to handle this? In the end, I would like to access last_state again when the model is restored, ideally without needing access to the original code that created the model.

As it turns out, looping over each element of the state is not too complicated and is straightforward to implement:

def add_to_collection_rnn_state(name, rnn_state):
    for layer in rnn_state:
        tf.add_to_collection(name, layer.c)
        tf.add_to_collection(name, layer.h)

And then to load it again:

def get_collection_rnn_state(name):
    layers = []
    coll = tf.get_collection(name)
    for i in range(0, len(coll), 2):
        state = tf.nn.rnn_cell.LSTMStateTuple(coll[i], coll[i+1])
        layers.append(state)
    return tuple(layers)

Note that this assumes that a collection only stores states, i.e. use a different collection for each state you want to store, for example like this:

add_to_collection_rnn_state('states', last_state)
add_to_collection_rnn_state('init_state', init_state)
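
If the goal is to access the state again after restoring from a checkpoint, without the original model-building code, these helpers combine naturally with tf.train.import_meta_graph. A minimal restore-side sketch, assuming the graph and variables were previously saved with a tf.train.Saver to the hypothetical path 'model.ckpt':

with tf.Session() as sess:
    # rebuild the graph structure from the .meta file; no model code needed
    saver = tf.train.import_meta_graph('model.ckpt.meta')  # hypothetical path
    saver.restore(sess, 'model.ckpt')
    # the individual state tensors were serialized with the graph's collections,
    # so they can be reassembled into state tuples by collection name
    restored_state = get_collection_rnn_state('states')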

EDIT

As correctly pointed out in the comments, the proposed solution only works for LSTM cells, whose state is also represented as a tuple. A more general solution that can handle GRU cells, possibly custom cells, and a mix of these could look like this:

import tensorflow as tf
import numpy as np
def add_to_collection_rnn_state(name, rnn_state):
    # store the name of each cell type in a different collection
    coll_of_names = name + '__names__'
    for layer in rnn_state:
        n = layer.__class__.__name__
        tf.add_to_collection(coll_of_names, n)
        try:
            for l in layer:
                tf.add_to_collection(name, l)
        except TypeError:
            # layer is not iterable so just add it directly
            tf.add_to_collection(name, layer)

def get_collection_rnn_state(name):
    layers = []
    coll = tf.get_collection(name)
    coll_of_names = tf.get_collection(name + '__names__')
    idx = 0
    for n in coll_of_names:
        if 'LSTMStateTuple' in n:
            state = tf.nn.rnn_cell.LSTMStateTuple(coll[idx], coll[idx+1])
            idx += 2
        else:  # add more cell types here
            state = coll[idx]
            idx += 1
        layers.append(state)
    return tuple(layers)

in_ = tf.placeholder(tf.float32, shape=[None, 5, 1])
batch_size = tf.shape(in_)[0]
cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
cell2 = tf.nn.rnn_cell.GRUCell(num_units=256)
cell3 = tf.nn.rnn_cell.BasicRNNCell(num_units=256)
cell = tf.nn.rnn_cell.MultiRNNCell([cell1, cell2, cell3])
outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
                                        inputs=in_,
                                        initial_state=cell.zero_state(batch_size, dtype=tf.float32))
add_to_collection_rnn_state('last_state', last_state)
last_state_r = get_collection_rnn_state('last_state')

Comparing last_state and last_state_r shows that both are identical (as they should be). Note that I am using a separate collection to store the names, because TensorFlow can only serialize a collection when all of its elements have the same type; e.g. mixing strings with tensors in the same collection does not work.
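
As a quick sanity check of that claim (a minimal sketch, assuming the graph built above): fetching both objects in one session run should yield matching values per layer, with the LSTM layer returned as an LSTMStateTuple of arrays and the GRU and vanilla RNN layers as plain arrays.

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    s, s_r = sess.run([last_state, last_state_r],
                      feed_dict={in_: np.ones([1, 5, 1])})
    for original, restored in zip(s, s_r):
        # both fetches point to the same underlying tensors, so the values match
        print(np.allclose(original, restored))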

Latest update