如何(有效地)在张力流中应用通道完全连接的层



我再次抓住你的头,我可以上班,但确实很慢。希望您可以帮助我优化它。

我正在尝试在TensorFlow中实现卷积自动编码器,并在编码器和解码器之间具有大潜在空间。通常,一个人会用完全连接的层将编码器连接到解码器,但是由于该潜在空间具有很高的维度,因此这样做会创造太多功能,无法在计算上可行。

我在本文中找到了一个很好的解决方案。他们称其为"通道全连接层"。它基本上是每个通道完全连接的层。

我正在研究实施,我可以工作,但是图的产生需要很长时间。到目前为止,这是我的代码:

def _network(self, dataset, isTraining):
        encoded = self._encoder(dataset, isTraining)
        with tf.variable_scope("fully_connected_channel_wise"):
            shape = encoded.get_shape().as_list()
            print(shape)
            channel_wise = tf.TensorArray(dtype=tf.float32, size=(shape[-1]))
            for i in range(shape[-1]):  # last index in shape should be the output channels of the last conv
                channel_wise = channel_wise.write(i, self._linearLayer(encoded[:,:,i], shape[1], shape[1]*4, 
                                  name='Channel-wise' + str(i), isTraining=isTraining))
            channel_wise = channel_wise.concat()
            reshape = tf.reshape(channel_wise, [shape[0], shape[1]*4, shape[-1]])
        reconstructed = self._decoder(reshape, isTraining)
        return reconstructed

那么,关于为什么要花这么长时间的想法?实际上,这是一个范围(2048),但是所有线性层都很小(4x16)。我是用错误的方式来接近这个吗?

谢谢!

您可以在TensorFlow中检查其对该论文的实现。这是他们实现"渠道全连接层"。

def channel_wise_fc_layer(self, input, name): # bottom: (7x7x512)
    _, width, height, n_feat_map = input.get_shape().as_list()
    input_reshape = tf.reshape( input, [-1, width*height, n_feat_map] )
    input_transpose = tf.transpose( input_reshape, [2,0,1] )
    with tf.variable_scope(name):
        W = tf.get_variable(
                "W",
                shape=[n_feat_map,width*height, width*height], # (512,49,49)
                initializer=tf.random_normal_initializer(0., 0.005))
        output = tf.batch_matmul(input_transpose, W)
    output_transpose = tf.transpose(output, [1,2,0])
    output_reshape = tf.reshape( output_transpose, [-1, height, width, n_feat_map] )
    return output_reshape

https://github.com/jazzsaxmafia/inpainting/blob/8c7735ec85393e0a1d40f05c11fa1686f9bd530f/src/src/src/model.py#l60

主要想法是使用tf.batch_matmul函数。

但是,tf.batch_matmul在最新版本的TensorFlow中被删除,您可以使用tf.matmul替换它。

相关内容

  • 没有找到相关文章

最新更新