如何检索tf.image.decode_jpeg返回的图像张量和高度?



我尝试设置一个图像管道,为裁剪图像的Tensorflow构建图像数据集。 我遵循了本教程,但我想将文件裁剪为正方形,而不是在不保留纵横比的情况下调整其大小。 我不知道如何获得它们的尺寸。

#
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
#
import glob

AUTOTUNE = tf.data.experimental.AUTOTUNE
IMAGE_SIZE = 192

def preprocess_image(path):
img_raw = tf.io.read_file(path)
img_tensor = tf.image.decode_jpeg(img_raw, channels=3)
print("img_tensor")
print(img_tensor)
height = img_tensor.shape[0]
print("height")
print(height)
return img_tensor

files_path = glob.glob('./images/*.jpeg')
image_count = len(files_path)
path_ds = tf.data.Dataset.from_tensor_slices(files_path)
path_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)

tf.image.decode_jpeg返回的张量形状为:

Tensor("DecodeJpeg:0", shape=(None, None, 3), dtype=uint8)

如何访问jpg图像的大小?

当我以这种方式访问它时,它可以工作:

#
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
#
image = tf.io.read_file('./images/4c34476047bcbbfd10b1fd3342605659.jpeg/')
image = tf.image.decode_jpeg(image, channels=3)
print("image.shape")
print(image.shape)

它打印:

image.shape
(700, 498, 3)

您面临此问题是因为数据集是延迟加载的(仅在需要时进行评估)。

从本质上讲,tf 只有在读取文件或我们作为开发人员告诉它时才能"知道"图像的大小。这似乎是一个显而易见的观点,但值得牢记。

因此,鉴于 tfDataset对象可以表示任意大的数据序列(事实上,以这种方式表示无限数据集是完全合理的),根据设计,它不会预先读取文件。相反,每当我们的下游代码需要新示例或批处理时,它都会读取它们。

恐怕我们真的需要知道图像的大小或预先针对所有可能的大小进行编码。

附言你可以让第二种方法工作的原因是它急切地评估(单个)张量示例。

附言您可能已经知道,您可以使用tf.shape()执行时"评估"任何张量的形状(并在数据集预处理管道中使用它的结果),但您无法预先检查它

我们可以做到。关键是投射tf.shape()的返回,这似乎是None,直到张量流图被执行。

以下代码调整图像的大小,保留纵横比,以便256高度或宽度(以较短者为准),然后随机裁剪为224x224

def preprocess(filename, label):
image = tf.image.decode_jpeg(tf.io.read_file(filename), channels=3)
# Resize the image by converting the smaller edge to 256 
shape = tf.shape(image)
_h, _w = shape[0], shape[1] 
_h, _w = tf.cast(_h, tf.float32), tf.cast(_w, tf.float32)
ratio = tf.math.divide(tf.constant(256.), tf.math.minimum(_h, _w))
ratio = tf.cast(ratio, tf.float32)
image = tf.image.resize(
image, tf.cast([_h*ratio, _w*ratio], tf.int32)
)   
image = tf.image.random_crop(image, [224,224,3])
return image, label

我将其用于ImageNet。

最新更新