如何创建将被重用的tensorflow操作



在构建/训练模型之前,我有一些数据需要处理。对于这个例子,我想做一个最大池2d。我写了一个简短的函数来实现tensorflow。

import tensorflow
import tensorflow.nn as nn
def _tfMaxPool(arr, pool=(4,4), sess=None):
op = nn.max_pool(arr, (1, 1, pool[0], 1), (1, 1, pool[0], 1 ), padding="VALID")
op = nn.max_pool(op, (1, 1, 1, pool[1]), (1, 1, 1, pool[1]), padding="VALID")
if sess is None:
sess = tensorflow.Session();
return sess.run(op)

问题是,每次都会在我的图中添加节点,这似乎会打乱我的会话。另一种方法是创建模型。

import keras
seq = keras.Sequential([ 
keras.layers.InputLayer((1, 512, 512)), 
keras.layers.MaxPool2D((4, 4), (4, 4), data_format="channels_first")
])
def _tfMaxPool2(arr, pool=(4,4), sess=None):
swapped = arr.swapaxes(0,1)
return seq.predict(swapped).swapaxes(0,1)

这个模型几乎和我想要的一模一样,但我认为我缺少了一些基本的东西。

为什么不重用具有各种输入的图形?在下面的代码中,tfMaxpool只定义了一次。

def _tfMaxPool(arr, pool=(4,4)):
op = nn.max_pool(arr, (1, 1, pool[0], 1), (1, 1, pool[0], 1 ), padding="VALID")
op = nn.max_pool(op, (1, 1, 1, pool[1]), (1, 1, 1, pool[1]), padding="VALID")
return op
input = tf.placeholder()
output = _tfMaxPool(input)
with tf.Session() as sess:
sess.run(output, feed_dict={input:arr1})
sess.run(output, feed_dict={input:arr2})

最新更新