如何将"tf.nn.dynamic_rnn"与非 rnn 组件一起使用



我有一个架构,在输入RNN之前使用编码器。编码器输入形状为[batch, height, width, channels],RNN 输入为形状[batch, time, height, width, channels]。我想将编码器的输出直接馈送到 RNN,但这会带来内存问题。我必须一次将batch*time ~= 3*100(通过重塑(图像输入编码器。我知道tf.nn.dynamic_rnn可以利用swap_memory,我也想在编码器中利用它。下面是一些精简的代码:

#image inputs [batch, time, height, width, channels]
inputs = tf.placeholder(tf.float32, [batch, time, in_sh[0], in_sh[1], in_sh[2]])
#This is where the trouble starts
#merge batch and time
inputs = tf.reshape(inputs, [batch*time, in_sh[0], in_sh[1], in_sh[2]])
#build the encoder (and get shape of output)
enc, enc_sh = build_encoder(inputs)
#change back to time format
enc = tf.reshape(enc, [batch, time, enc_sh[0], enc_sh[1], enc_sh[2]])
#build rnn and get initial state (zero_state)
rnn, initial_state = build_rnn()
#use dynamic unrolling
rnn_outputs, rnn_state = tf.nn.dynamic_rnn(
rnn, enc,
initial_state=initial_state,
swap_memory=True,
time_major=False)

我目前使用的方法是先验地在我的所有图像上运行编码器(并保存到光盘(,但我想执行数据集增强(图像(,一旦提取特征,这是不可能的。

对于遇到此问题的任何其他人。我制作了一个从RNNCell派生的包装器,可以满足我的需求。model_fn是一个使用输入构建子图并返回输出张量的函数。不幸的是,必须知道输出形状(至少我无法让它工作(。

class WrapperCell(tf.nn.rnn_cell.RNNCell):
"""A Wrapper for a non recurrent component that feeds into an RNN."""
def __init__(self, model_fn, out_shape, reuse=None):
super(WrapperCell, self).__init__(_reuse=reuse)
self.out_shape = out_shape
self.model_fn = model_fn
@property
def state_size(self):
return tf.TensorShape([1])
@property
def output_size(self):
return tf.TensorShape(self.out_shape)
def call(self, x, h):
mod = self.model_fn(x)
return mod, tf.zeros_like(h)

最新更新