如何在 Tensorflow 输入管道中堆叠通道?



几周前我开始使用tf,现在正在为输入队列而苦苦挣扎。 我想做的是:我有一个文件夹,里面有 477 张临时灰度图像。现在我想例如,将前 3 张图像堆叠在一起 (=> 600,600,3(,以便我得到一个具有 3 个通道的示例。接下来,我想获取第四张图像并将其用作标签(仅 1 个通道 => 600,600,1(。然后我想将两者传递给 tf.train.batch 并创建批处理。

我想我找到了一个解决方案,请参阅下面的代码。但我想知道是否有更时尚的解决方案。

我的实际问题是:队列末尾会发生什么。由于我总是从队列中挑选 4 张图像(3 张用于输入,1 张用于标签(,并且我的队列中有 477 张图像,因此事情没有解决。然后 tf 是否只是再次填满我的队列并继续(因此,如果队列中还剩下 1 张图像,它会获取此图像,再次填满队列并再拍摄 2 张图像以获得所需的 3 张图像?或者如果我想要一个合适的解决方案,我是否需要在我的文件夹中有许多可被 4 整除的图像?

def read_image(filename_queue):
reader = tf.WholeFileReader()
_, value = reader.read(filename_queue)
image = tf.image.decode_png(value, dtype=tf.uint8)
image = tf.cast(image, tf.float32)
image = tf.image.resize_images(image, [600, 600])
return image
def input_pipeline(file_names, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(file_names, num_epochs=num_epochs, shuffle=False)
image1 = read_image(filename_queue)
image2 = read_image(filename_queue)
image3 = read_image(filename_queue)
image = tf.concat([image1, image2, image3,], axis=2)
label = read.image(filename_queue)
# Reshape is necessary, otherwise I get an error..
image = tf.reshape(image, [600, 600, 3])
label = tf.reshape(label, [600, 600, 1])
min_after_dequeue = 200
capacity = min_after_dequeue + 12 * batch_size
image_batch, label_batch = tf.train.batch([image, label],
batch_size=batch_size,
num_threads=12,
capacity=capacity)
return image_batch, label_batch

感谢您的任何帮助!

但我想知道是否有更时尚的解决方案

是的!有一个更好、更快的解决方案。首先,重新设计数据库,因为您希望将 3 个灰色图像合并为 1 个 rgb 图像进行训练;从灰色图像准备一个RGB图像的数据集(这将在训练过程中节省大量时间(。

重新设计检索数据的方式

# retrieve image and corresponding label at the same time 
# here if you set the num_epochs=None, the queue will run continuously; and it will take-care of the data need for training till end
filename_queue = tf.train.string_input_producer([file_names_images_list, corresponding_file_names_label_list], num_epochs=None, shuffle=False)
image = read_image(filename_queue[0])
label = read_image(filename_queue[1])

最新更新