加载numpy数组到Tensorflow输入管道



所以我是按照一个教程为图像制作数据加载器(https://github.com/codebasics/deep-learning-keras-tf-tutorial/blob/master/44_tf_data_pipeline/tf_data_pipeline.ipynb)。

完整代码是这样的:

images_ds = tf.data.Dataset.list_files("path/class/*")
def get_label(file_path):
import os
parts = tf.strings.split(file_path, os.path.sep)
return parts[-2]
## How the tutorial does it
def process_image(file_path):
label = get_label(file_path)
img = tf.io.read_file(file_path)
img = tf.image.decode_jpeg(img)
return img, label
## How I want to do it
def process_image(file_path):
label = get_label(file_path)

img = np.load(file_path)
img = tf.convert_to_tensor(img) 
return img, label
train_ds = images_ds.map(process_image)

在本教程中,数据是.jpeg格式。但是,我的数据是。npy.

因此,用以下代码加载数据不起作用:

img = tf.io.read_file(file_path)
img = tf.image.decode_jpeg(img)

我想解决这个问题,但是我的解决方案行不通。

img = np.load(file_path)
img = tf.convert_to_tensor(img) 

当我提供process_image函数1实例时,它确实工作。但是,当我使用.map函数时,我得到一个错误。

错误:
TypeError: expected str, bytes or os。PathLike对象,不是张量

是否有一个等效的函数tf.image.decode_image()解码numpy数组和/或有人可以帮助我与我当前的错误?

@ andr的评论给了我正确的方向。下面的代码可以工作。


def process_image(file_path):
label = get_label(file_path)
label = np.uint8(label)
img = np.load(file_path)
img = tf.convert_to_tensor(img/255, dtype=tf.float32) 
return img , label 
train_ds = images_ds.map(lambda item: tf.numpy_function(
process_image, [item], (tf.float32, tf.uint8))) 

相关内容

最新更新