如何循环浏览张量流中张量的元素



我想使用Tensorflow的tf.image.rot90()创建一个函数batch_rot90(batch_of_images),后者一次仅拍摄一个图像,前者应该一次拍摄一批n张图像(shape = [n,x,y,f](。

自然而然地,应该只通过批处理中的所有图像进行迭代,然后将它们一个一个一个旋转。在numpy中,这看起来像:

def batch_rot90(batch):
  for i in range(batch.shape[0]):
    batch_of_images[i] = rot90(batch[i,:,:,:])
  return batch

在TensorFlow中如何完成?使用tf.while_loop我走了很远:

batch = tf.placeholder(tf.float32, shape=[2, 256, 256, 4])    
def batch_rot90(batch, k, name=''):
      i = tf.constant(0)
      def cond(batch, i):
        return tf.less(i, tf.shape(batch)[0])
      def body(im, i):
        batch[i] = tf.image.rot90(batch[i], k)
        i = tf.add(i, 1)
        return batch, i  
      r = tf.while_loop(cond, body, [batch, i])
      return r

但是不允许使用im[i]的分配,我对r。

返回的内容感到困惑

我意识到使用tf.batch_to_space()可以解决此特定情况的解决方法,但我相信,也应该有某种循环。

更新答案:

x = tf.placeholder(tf.float32, shape=[2, 3])
def cond(batch, output, i):
    return tf.less(i, tf.shape(batch)[0])
def body(batch, output, i):
    output = output.write(i, tf.add(batch[i], 10))
    return batch, output, i + 1
# TensorArray is a data structure that support dynamic writing
output_ta = tf.TensorArray(dtype=tf.float32,
               size=0,
               dynamic_size=True,
               element_shape=(x.get_shape()[1],))
_, output_op, _  = tf.while_loop(cond, body, [x, output_ta, 0])
output_op = output_op.stack()
with tf.Session() as sess:
    print(sess.run(output_op, feed_dict={x: [[1, 2, 3], [0, 0, 0]]}))

我认为您应该考虑使用tf.scatter_update在批处理中更新一个图像,而不是使用batch[i] = ...。有关详细信息,请参阅此链接。在您的情况下,我建议将第一行的主体更改为:

tf.scatter_update(batch, i, tf.image.rot90(batch[i], k))

TF中有一个地图函数,可以工作:

def batch_rot90(batch, k, name=''):
  fun = lambda x: tf.images.rot90(x, k = 1)
  return = tf.map_fn(fun, batch)

最新更新