我在构建图形时构建了三个数据输入管道。
images_pipe_1 = input_images('list1')
images_pipe_2 = input_images('list2')
images_pipe_3 = input_images('list3')
我想根据global_step在图形运行时中选择其中一个,如下所示:
if global_step < 2000:
data input pipeline = images_pipe_1
if global_step >= 2000 and global_step < 5000
data input pipeline = images_pipe_2
if global_step >= 5000
data input pipeline = images_pipe_3
但是在 tensorflow 中,有像 global_step 这样的变量是张量,它们应该由 tf 函数操作,而不是由 python 操作。 我尝试使用tf.cond,但它只能解决两个选项的问题。
images_pipe = tf.cond(tf.greater(global_step, tf.constant(2000, tf.int64)), lambda:images_pipe_2, lambda:images_pipe_1)
在这种情况下,有三个选项。我不知道我该如何解决它。提前感谢您的帮助。
我通过 tf.case 解决它
pipeline = tf.case({tf.greater(global_step, tf.constant(5000,tf.int64)):images-pipe_3, tf.less(global_step, tf.constant(2000,tf.int64):images_pipe_1)}, default=images_pipe_2, exclusive=True)