Tensorflow1和Tensorflow2中的批处理



我正试图将图像单应性代码从TF1版本转换为TF2版本,但这里的TF脚本转换不起作用。我一直在处理数据集,因为图像、image_patch和image_Indices具有不同的形状。虽然TF1在接收和批处理数据集的包时没有问题,但TF2有问题。

imgs= np.random.rand(11,240,320,3)
pts = np.random.randint(100, size =(11,8))
patch = np.random.rand(11,128,128,1)
imgs = tf.convert_to_tensor(imgs)
pts = tf.convert_to_tensor(pts)
patch = tf.convert_to_tensor(patch)
pts= tf.cast(pts,dtype=tf.float64)

tensorflow2:

img_batch,pts_batch,patch_batch = tf.data.Dataset.from_tensor_slices([imgs,pts,patch]).shuffle(buffer_size=batch_size*4)

这里,11是图像的数量,240和320是图像尺寸,3是通道的数量。

错误-

tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [11,240,320,3] != values[2].shape = [11,128,128,1] [Op:Pack] name: component_0

tensorflow1:

tf.compat.v1.train.batch([imgs,pts,patch], batch_size=5)

输出-

[<tf.Tensor 'batch_2:0' shape=(5, 11, 240, 320, 3) dtype=float64>,
<tf.Tensor 'batch_2:1' shape=(5, 11, 8) dtype=float64>,
<tf.Tensor 'batch_2:2' shape=(5, 11, 128, 128, 1) dtype=float64>]

如何在tensorflow2中批量处理不同维度的数据集?同样运行;tf.compat.v1.train.batch(("在TF2(tensoflow版本2.3(中不起作用,因为它给出了紧急执行错误。

在TF2中批处理此类数据集的正确方法是什么?

这里的问题不是批处理,而是tf.data.Dataset本身的生成。错误是由img_batch,pts_batch,patch_batch = tf.data.Dataset.from_tensor_slices([imgs,pts,patch])引起的,而不是由.shuffle(batch_size=...)引起的。

我认为.from_tensor_slices在这里的水平太高了,看看tf.data.Dataset.from_generator

最新更新