TF-Keras-多输入功能API模型的Dataset.from_generator



我有一个生成三个变量的生成器。前两个变量是两输入Keras模型(函数API(的两个输入。我正在使用TF数据集来提供我的模型。代码如下:


train_dataset = tf.data.Dataset.from_generator(generator=make_generator_train,
args=[train_x_paths, train_y_int],
output_types=(tf.tuple((tf.float16, tf.float16)), tf.int8),
output_shapes=(tf.TensorShape([2]),
tf.TensorShape([1]))).batch(batch_size=batch_size)

我得到了一个TypeError,上面写着:

TypeError:如果浅结构是序列,则输入也必须是序列。输入的类型为:<类'tensorflow.python.framework.tensor_shape.TensorShape'>。

这样试试:

train_dataset = tf.data.Dataset.from_generator(
generator=make_generator_train,
args=[train_x_paths, train_y_int],
output_types=(tf.float16, tf.int8)
).batch(batch_size=batch_size)

大多数情况下,您不需要指定output_shapes。它是在运行时决定的。此外,您只需要在output_types中指定输出张量的总体dtype。不是每个张量维度的数据类型。

解决方案:生成器应该为输入和输出生成一个字典。

最新更新