TensorFlow tfrecords: tostring() 更改图像的维度



我已经建立了一个模型来训练TensorFlow中的卷积自动编码器。我按照从 TF 文档中读取数据的说明读取了我自己的大小为 233 x 233 x 3 的图像。这是我根据这些指令改编的 convert_to(( 函数:

def convert_to(images, name):
"""Converts a dataset to tfrecords."""
num_examples = images.shape[0]
rows = images.shape[1]
cols = images.shape[2]
depth = images.shape[3]
filename = os.path.join(FLAGS.tmp_dir, name + '.tfrecords')
print('Writing', filename)
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
print(images[index].size)
image_raw = images[index].tostring()
print(len(image_raw))
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
writer.close()

当我在 for 循环开始时打印图像的大小时,大小是162867,但是当我在 .tostring(( 行之后打印时,大小是1302936。这会导致问题,因为模型认为我的输入是应有的 8 倍。将示例中的"image_raw"条目更改为 _int64_feature(image_raw( 还是更改将其转换为字符串的方式更好?

或者,问题可能出在我的 read_and_decode(( 函数中,例如字符串未正确解码或示例未解析......?

def read_and_decode(self, filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'depth': tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string)
})
# Convert from a scalar string tensor to a uint8 tensor
image = tf.decode_raw(features['image_raw'], tf.uint8)
# Reshape into a 233 x 233 x 3 image and apply distortions
image = tf.reshape(image, (self.input_rows, self.input_cols, self.num_filters))
image = data_sets.normalize(image)
image = data_sets.apply_augmentation(image)
return image

谢谢!

我可能对你的问题有一些答案。

首先,使用.tostring()方法后,您的图像长 8 倍是完全正常的。后者以字节为单位转换数组。它的名字很糟糕,因为在python 3中,字节与字符串不同(但它们在python 2中是相同的(。默认情况下,我猜您的图像是在 int64 中定义的,因此每个元素都将使用 8 个字节(或 64 位(进行编码。在您的示例中,图像的 162867 像素以 1302936 字节编码...

关于您在解析过程中的错误,我认为它来自这样一个事实,即您用 int64 写入数据(用 64 位编码的整数,所以 8 个字节(并在 uint8 中读取它们(用 8 位编码的无符号整数,所以 1 个字节(。如果在 int64 或 int8 中定义,则同一整数将具有不同的字节序列。使用 tfrecord 文件时,以字节为单位写入映像是一种很好的做法,但您也需要使用正确的类型以字节为单位读取它们。

对于您的代码,请尝试改为image = tf.decode_raw(features['image_raw'], tf.int64)

错误似乎在这里。

#Convert from a scalar string tensor to a uint8 tensor
image = tf.decode_raw(features['image_raw'], tf.uint8)#the image looks like the tensor with 1302936 values.
image.set_shape([self.input_rows*self.input_cols*self.num_filters])#self.input_rows*self.input_cols*self.num_filters equals 162867, right?

这就是我的全部猜测,因为您提供的代码太少了。

相关内容

  • 没有找到相关文章

最新更新