不缓存地复制Keras教程



我试图在官方Keras网站上复制这个教程。本教程是关于迁移学习的,它是一个关于如何在著名的猫对狗数据集上使用预训练模型的指导示例。

我的问题是关于他们做缓存和调整缓冲区大小的部分,执行如下:

batch_size = 32
train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

如果我跳过这一部分,我就不能再复制教程了,因为我在适合模型的地方会出现错误。错误如下:

ValueError: Input 0 is incompatible with layer model_7: expected shape=(None, 150, 150, 3), found shape=(150, 150, 3)

我需要做什么修改来运行训练而不用担心缓存和相关的东西?

要从这段代码中删除缓存和预取(使用缓存机制),只需从输入管道中删除这些方法,如下所示:

batch_size = 32
train_ds = train_ds.batch(batch_size)
validation_ds = validation_ds.batch(batch_size)
test_ds = test_ds.batch(batch_size)

教程中的其余代码将正常工作,除了数据集将不使用缓存。

正如其他人所指出的那样,.batch()方法独立于缓存,但需要将数据集分批排列并依次提供给模型。当您跳过整个代码块时,您收到的错误与跳过.batch()方法有关。

.batch()为数据集添加了一个外部维度,标准keras模型和层(如该教程中使用的那些)期望作为输入。这就是为什么你会得到错误"预期形状=(None, 150,150,3),发现形状=(150,150,3)"

您可以在这里阅读更多关于方法链的信息,在这里阅读批处理和.batch()方法,.cache()方法和.prefetch()方法。

相关内容

  • 没有找到相关文章

最新更新