使用数据集迭代器作为 tensorflow model.fit() 的输入时出错



当我运行这段代码时:

x_train = tfds.load('ucf101', split='train', shuffle_files=True, batch_size = 64)
dim = lambda x: x['video'][:,30:40, ...]
x_train = x_train.map(dim)
model.compile(loss='mse',
optimizer=tf.keras.optimizers.Adam(1e-4),
metrics=['accuracy'])
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
history = model.fit(x_train, x_train,
epochs=100,
verbose=1,
callbacks=[cp_callback])

我收到错误:

ValueError: You passed a dataset or dataset iterator (<MapDataset shapes: (None, None, 256, 256, 3), types: tf.uint8>) as inputxto your model. In that case, you should not specify a target (y) argument, since the dataset or dataset iterator generates both input data and target data. Received: <MapDataset shapes: (None, None, 256, 256, 3), types: tf.uint8>

它是一个自动编码器,因此有意提供x_train作为输入和目标。MapDataset的维度是(批处理,框架,高度,宽度,rgb(,它不包含任何目标数据。

> 通过 https://www.tensorflow.org/api_docs/python/tf/keras/Model:

如果 x 是数据集、生成器或 keras.utils.Sequence 实例,则不应指定 y(因为目标将从迭代器/数据集中获取(。

如果要使用 model.fit((,则需要在 tf.data 数据集中指定目标,如下所示:

x_train = tfds.load('ucf101', split='train', shuffle_files=True, batch_size = 64)
def parser(x):
dim = x['video'][:,30:40, ...]
return dim, dim
x_train = x_train.map(parser)
...
history = model.fit(x_train,
epochs=100,
verbose=1,
callbacks=[cp_callback])

最新更新