Connecting BatchDataset with Keras VGG16 preprocess_input



我使用tf.keras.preprocessing.image_dataset_from_directory来获得BatchDataset,其中数据集有10个类。

我正试图将这个BatchDataset与KerasVGG16(docs)网络集成。来自文档:

注意:每个Keras应用程序期望一种特定类型的输入预处理。对于VGG16,在将它们传递给模型之前,对输入调用tf.keras.applications.vgg16.preprocess_input

然而,我正在努力让这个preprocess_inputBatchDataset一起工作。你能帮我弄清楚如何连接这两个点吗?

请看下面的代码:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(train_data_dir, image_size=(224, 224))
train_ds = tf.keras.applications.vgg16.preprocess_input(train_ds)

这将抛出TypeError: 'BatchDataset' object is not subscriptable:

Traceback (most recent call last):
...
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/vgg16.py", line 232, in preprocess_input
return imagenet_utils.preprocess_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 117, in preprocess_input
return _preprocess_symbolic_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 278, in _preprocess_symbolic_input
x = x[..., ::-1]
TypeError: 'BatchDataset' object is not subscriptable

From TypeError: 'DatasetV1Adapter' object is not subscriptable (From BatchDataset not subscriptable当尝试将Python字典格式化为表时)建议使用:

train_ds = tf.keras.applications.vgg16.preprocess_input(
list(train_ds.as_numpy_iterator())
)

然而,这也失败了:

Traceback (most recent call last):
...
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/vgg16.py", line 232, in preprocess_input
return imagenet_utils.preprocess_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 117, in preprocess_input
return _preprocess_symbolic_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 278, in _preprocess_symbolic_input
x = x[..., ::-1]
TypeError: list indices must be integers or slices, not tuple

这些都是使用Python==3.10.3tensorflow==2.8.0

我怎样才能使它工作?提前谢谢你。

好吧,我明白了。我需要通过tf.Tensor,而不是tf.data.Dataset。可以通过迭代Dataset得到Tensor

这可以通过以下几种方式完成:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(...)
# Option 1
batch_images = next(iter(train_ds))[0]
preprocessed_images = tf.keras.applications.vgg16.preprocess_input(batch_images)
# Option 2:
for batch_images, batch_labels in train_ds:
preprocessed_images = tf.keras.applications.vgg16.preprocess_input(batch_images)

如果将选项2转换为生成器,则可以直接传递到下游model.fit。干杯!

相关内容

  • 没有找到相关文章

最新更新