如何正确地将tf.function与TensorFlow数据集一起使用



我正在尝试将TF数据集与@tf.函数一起使用,对图像目录进行一些预处理。在tf函数中,图像文件被读取为 RAW 字符串张量,我正在尝试从该张量中获取切片。切片(前 13 个字符(表示有关 .ppm 图像(标题(的信息。我收到一个错误:ValueError: Shape must be rank 1 but is rank 0 for 'Slice' (op: 'Slice') with input shapes: [], [1], [1].最初我试图直接切片张量的 .numpy(( 属性(filepathtf 函数的输入参数(,但我认为在tf函数中这样做在语义上是错误的。它也不起作用,因为filepath输入张量没有 numpy(( 属性(我不明白为什么??在tf函数之外,例如在 jupyter 笔记本单元格中,我可以迭代数据集并获取具有 numpy 属性的单个项目,并对其进行切片和所有后续处理。我确实意识到我对 TF 工作原理的理解可能存在差距(我正在使用 TF 2.0(,所以我希望有人可以澄清我在阅读中遗漏的内容。tf函数的目的是将 ppm 图像转换为 png,因此此功能有一个副作用,但我没有走那么远来找出这是否可以做到。

代码如下:

@tf.function
def ppm_to_png(filepath):
ppm_bytes = tf.io.read_file(filepath) #.numpy()
bytes_header = tf.slice(ppm_bytes, [0], [13])
# bytes_header = ppm_bytes[:13].eval()  # this did not work either with similar error msg
.
.
.
import glob
files = glob.glob(os.path.join(data_dir, '00000/*.ppm'))
dataset = tf.data.Dataset.from_tensor_slices(files)
png_filepaths = dataset.map(ppm_to_png, num_parallel_calls=tf.data.experimental.AUTOTUNE)

若要操作 TF 中的字符串值,请查看 tf.string 命名空间。

在这种情况下,您可以使用tf.strings.substr

@tf.function
def ppm_to_png(filepath):
ppm_bytes = tf.io.read_file(filepath)
bytes_header = tf.strings.substr(ppm_bytes, 0, 13)
tf.print(bytes_header)

tf.slice只对张量对象进行操作,不对它们的元素进行操作。这里,ppm_bytes是一个标量张量,包含一个类型为tf.string的元素,其值是文件的全部字符串内容。因此,当您调用tf.slice时,它只查看标量位,并且不够聪明,无法意识到您实际上想要获取该元素的一部分。

最新更新