为什么我会收到类型错误"Image data cannot be converted to float"?



加载图像时,我试图通过将它们打印在Pyplot中,以确保它们正确加载,但是我有问题。如何将这些图像加载到TensorFlow中,并使用Pyplot的imshow()(或其他方式)检查它们?

图像数据是单渠道(黑白)JPEG。它最初被加载为具有未知形状和UINT8 dtype的张量。我已经尝试确保将张量重塑为正确的形状并铸造为float32。我还尝试确保将值从0.0-1.0缩放为浮点,并使用imshow()函数中的灰色CMAPPING缩放。

import tensorflow as tf
import matplotlib.pyplot as plt
def load_and_preprocess_jpeg(imagepath):
    img = tf.read_file(imagepath)
    img_tensor = tf.image.decode_jpeg(img)
    img_tensor.set_shape([792,1224,1])
    img_tensor = tf.reshape(img_tensor, [792,1224])
    img_tensor = tf.cast(img_tensor, tf.float32, name='ImageCast')
    #img_tensor /= 255.0 #Tried with and without
    return img_tensor
def read_data(all_filenames):
    path_Dataset = tf.data.Dataset.from_tensor_slices(all_filenames)
    image_Dataset = path_Dataset.map(load_and_preprocess_jpeg)
    plt.figure(figsize=(8,8))
    temp_DS = image_Dataset.take(4)
    itera = temp_DS.make_one_shot_iterator()
    for n in range(4):
        image = itera.get_next()
        plt.subplot(2,2,n+1)
        plt.imshow(image)
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])

我的堆栈跟踪:

File "<stdin>", line 1, in <module>
line 34, in read_data
  plt.imshow(image)
matplotlibpyplot.py, line 3205, in imshow
  **kwargs)
matplotlib__init__.py, line 1855, in inner
  return func(ax, *args, **kwargs)
matplotlibaxes_axes.py, line 5487, in imshow
  im.set_data(X)
matplotlibimage.py, line 649, in set_data
  raise TypeError("Image data cannot be converted to float")

您正在尝试绘制张量。为了绘制图像,您必须先运行会话。尝试以下代码:

import tensorflow as tf
import matplotlib.pyplot as plt
def load_and_preprocess_jpeg(imagepath):
    img = tf.read_file(imagepath)
    img_tensor = tf.image.decode_jpeg(img)
    img_tensor = tf.image.resize_images(img_tensor, [img_size,img_size])
    img_tensor = tf.cast(img_tensor, tf.float32, name='ImageCast')
    img_tensor /= 255.0 
    return img_tensor
path_Dataset = tf.data.Dataset.from_tensor_slices(all_filenames)
image_Dataset = path_Dataset.map(load_and_preprocess_jpeg)
temp_DS = image_Dataset.take(4)
itera = temp_DS.make_one_shot_iterator()
image = itera.get_next()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    while True:
        try:
            image_to_plot = sess.run(image)
            plt.figure(figsize=(8,8))
            plt.subplot(2,2,n+1)
            plt.imshow(image_to_plot)
            plt.grid(False)
            plt.xticks([])
            plt.yticks([])
        except tf.errors.OutOfRangeError:
            break 

相关内容

最新更新