如何在 Tensorflow 的图运行时中选择三个数据输入管道之一?



我在构建图形时构建了三个数据输入管道。

images_pipe_1 = input_images('list1')
images_pipe_2 = input_images('list2')
images_pipe_3 = input_images('list3')

我想根据global_step在图形运行时中选择其中一个,如下所示:

if global_step < 2000:
data input pipeline = images_pipe_1
if global_step >= 2000 and global_step < 5000
data input pipeline = images_pipe_2
if global_step >= 5000
data input pipeline = images_pipe_3

但是在 tensorflow 中,有像 global_step 这样的变量是张量,它们应该由 tf 函数操作,而不是由 python 操作。 我尝试使用tf.cond,但它只能解决两个选项的问题。

images_pipe = tf.cond(tf.greater(global_step, tf.constant(2000, tf.int64)), lambda:images_pipe_2, lambda:images_pipe_1)

在这种情况下,有三个选项。我不知道我该如何解决它。提前感谢您的帮助。

我通过 tf.case 解决它

pipeline = tf.case({tf.greater(global_step, tf.constant(5000,tf.int64)):images-pipe_3, tf.less(global_step, tf.constant(2000,tf.int64):images_pipe_1)}, default=images_pipe_2, exclusive=True)

最新更新