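I often use tf.add_to_collection to have TensorFlow automatically serialize intermediate results into the checkpoint. I find this the most convenient way to get pointers to interesting tensors later, when the model is restored from a checkpoint. However, I realized that RNN state tuples cannot easily be added to a graph collection. Consider the following dummy example in TF 1.3: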
import tensorflow as tf
import numpy as np
in_ = tf.placeholder(tf.float32, shape=[None, 5, 1])
batch_size = tf.shape(in_)[0]
cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
cell2 = tf.nn.rnn_cell.BasicLSTMCell(num_units=256)
cell = tf.nn.rnn_cell.MultiRNNCell([cell1, cell2])
outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
                                        inputs=in_,
                                        initial_state=cell.zero_state(batch_size, dtype=tf.float32))
tf.add_to_collection('states', last_state)
loss = tf.reduce_mean(in_ - outputs)
loss_s = tf.summary.scalar('loss', loss)
writer = tf.summary.FileWriter('.', tf.get_default_graph())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    l, s = sess.run([loss, loss_s], feed_dict={in_: np.ones([1, 5, 1])})
    writer.add_summary(s)
This produces the following warning:
WARNING:tensorflow:Error encountered when serializing states.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'tuple' object has no attribute 'name'
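It seems that serialization cannot handle tuples, and of course the last_state variable is a tuple. One could probably loop through the tuple and add each element to the collection individually, but that seems overly complicated. Is there a better way to handle this? In the end, I would like to access last_state again when the model is restored, ideally without needing access to the original code that created the model.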
Actually, looping over every element of the state is not too complicated and is straightforward to implement:
def add_to_collection_rnn_state(name, rnn_state):
    for layer in rnn_state:
        tf.add_to_collection(name, layer.c)
        tf.add_to_collection(name, layer.h)
And then to load it:
def get_collection_rnn_state(name):
    layers = []
    coll = tf.get_collection(name)
    for i in range(0, len(coll), 2):
        state = tf.nn.rnn_cell.LSTMStateTuple(coll[i], coll[i+1])
        layers.append(state)
    return tuple(layers)
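Note that this assumes a collection stores only one state, i.e. you use a different collection for each state you want to store, for example like this: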
add_to_collection_rnn_state('states', last_state)
add_to_collection_rnn_state('init_state', init_state)
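To illustrate why each state needs its own collection, here is a minimal sketch that runs without TensorFlow. The names StateTuple, collections, add_state and get_state are hypothetical stand-ins for LSTMStateTuple, the graph collections and the two helpers above, and plain strings stand in for tensors. It only mimics the flat [c, h, c, h, ...] layout; it is not how TensorFlow stores collections internally.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple (assumption, not the TF class).
StateTuple = namedtuple("StateTuple", ["c", "h"])

# Toy replacement for the graph collections: name -> flat list of items.
collections = {}


def add_state(name, rnn_state):
    """Append c and h of every layer to the flat list stored under `name`."""
    for layer in rnn_state:
        collections.setdefault(name, []).extend([layer.c, layer.h])


def get_state(name):
    """Rebuild the per-layer tuples by reading the flat list two items at a time."""
    coll = collections.get(name, [])
    return tuple(StateTuple(coll[i], coll[i + 1]) for i in range(0, len(coll), 2))


# Strings stand in for tensors here.
last_state = (StateTuple("c0", "h0"), StateTuple("c1", "h1"))
init_state = (StateTuple("ic0", "ih0"), StateTuple("ic1", "ih1"))

# One collection per state: the round trip gives back the original structure.
add_state("states", last_state)
add_state("init_state", init_state)
restored = get_state("states")

# Two states in one collection: the layers of both states run together.
add_state("shared", init_state)
add_state("shared", last_state)
mixed = get_state("shared")
```

With separate collections, restored equals last_state. With the shared collection, mixed contains four layers and there is no way to tell where one state ends and the next begins.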
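EDIT
As correctly pointed out in the comments, the proposed solution only works for LSTM cells (whose state is also represented as a tuple). A more general solution that can handle GRU cells, possibly custom cells, and a mix of them could look like this: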
import tensorflow as tf
import numpy as np
def add_to_collection_rnn_state(name, rnn_state):
    # store the name of each cell type in a different collection
    coll_of_names = name + '__names__'
    for layer in rnn_state:
        n = layer.__class__.__name__
        tf.add_to_collection(coll_of_names, n)
        try:
            for l in layer:
                tf.add_to_collection(name, l)
        except TypeError:
            # layer is not iterable so just add it directly
            tf.add_to_collection(name, layer)

def get_collection_rnn_state(name):
    layers = []
    coll = tf.get_collection(name)
    coll_of_names = tf.get_collection(name + '__names__')
    idx = 0
    for n in coll_of_names:
        if 'LSTMStateTuple' in n:
            state = tf.nn.rnn_cell.LSTMStateTuple(coll[idx], coll[idx+1])
            idx += 2
        else:  # add more cell types here
            state = coll[idx]
            idx += 1
        layers.append(state)
    return tuple(layers)

in_ = tf.placeholder(tf.float32, shape=[None, 5, 1])
batch_size = tf.shape(in_)[0]
cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
cell2 = tf.nn.rnn_cell.GRUCell(num_units=256)
cell3 = tf.nn.rnn_cell.BasicRNNCell(num_units=256)
cell = tf.nn.rnn_cell.MultiRNNCell([cell1, cell2, cell3])
outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
                                        inputs=in_,
                                        initial_state=cell.zero_state(batch_size, dtype=tf.float32))
add_to_collection_rnn_state('last_state', last_state)
last_state_r = get_collection_rnn_state('last_state')
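Comparing last_state and last_state_r shows that both are the same (as they should be). Note that I am using a different collection to store the names, because TensorFlow can only serialize a collection when all of its elements have the same type. For example, mixing strings with tensors in the same collection does not work.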
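One way to do that comparison programmatically is sketched below. It runs without TensorFlow: LSTMState is a hypothetical stand-in for LSTMStateTuple, plain object() instances stand in for tensors, and flatten_state and same_state are illustrative helper names, not TensorFlow functions. The check uses object identity, on the assumption that within the same graph and process the collection hands back the very same tensor objects. After restoring from a meta graph in a new process, comparing tensor names would probably be the better test.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple (assumption, not the TF class).
LSTMState = namedtuple("LSTMState", ["c", "h"])


def flatten_state(rnn_state):
    """Flatten a per-layer state into one list; tuple layers contribute each element."""
    flat = []
    for layer in rnn_state:
        if isinstance(layer, tuple):
            flat.extend(layer)
        else:
            flat.append(layer)
    return flat


def same_state(a, b):
    """True if both states have the same layer types and the same underlying objects."""
    if len(a) != len(b):
        return False
    if any(type(x) is not type(y) for x, y in zip(a, b)):
        return False
    flat_a, flat_b = flatten_state(a), flatten_state(b)
    return len(flat_a) == len(flat_b) and all(x is y for x, y in zip(flat_a, flat_b))


# Plain objects stand in for the tensors of an LSTM, a GRU and a basic RNN layer.
c0, h0, gru_h, rnn_h = object(), object(), object(), object()
original = (LSTMState(c0, h0), gru_h, rnn_h)
rebuilt = (LSTMState(c0, h0), gru_h, rnn_h)
swapped = (LSTMState(h0, c0), gru_h, rnn_h)

matches = same_state(original, rebuilt)
detects_swap = not same_state(original, swapped)
```

Here matches is True because every layer has the same type and wraps the same objects, while swapped fails the check because c and h were exchanged in the first layer.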