The full warning is: WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least steps_per_epoch * epochs batches (in this case, 3400 batches). You may need to use the repeat() function when building your dataset.
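(For reference, the 3400 batches mentioned in the warning is just steps_per_epoch * epochs as computed from the values in the training code below; a quick check using the numbers from this question:)

```python
# Reproduce the batch count in the warning from the fit() arguments below.
nb_train_samples = 2584
batch_size_train = 76
epochs = 100

steps_per_epoch = nb_train_samples // batch_size_train  # 34 steps per epoch
print(steps_per_epoch * epochs)  # 3400 batches needed in total
```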
# importing libraries
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K
import tensorflow as tf
train_data_dir = 'marvel/train'
validation_data_dir = 'marvel/valid'
nb_train_samples = 2584
nb_validation_samples = 451
epochs = 100
batch_size_train = 76
batch_size_val = 41
# Note: input_shape is computed here but the model below hardcodes (200, 200, 3).
if K.image_data_format() == 'channels_first':
    input_shape = (3, 200, 200)
else:
    input_shape = (200, 200, 3)
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(200, 200, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(8, activation='softmax')
])
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1. / 255)
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(200, 200),
    batch_size=batch_size_train,
    classes=['black widow', 'captain america', 'doctor strange', 'hulk', 'iron man', 'loki', 'spiderman', 'thanos'],
    class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(200, 200),
    batch_size=batch_size_val,
    class_mode='categorical')
model.fit(train_generator,
          steps_per_epoch=nb_train_samples // batch_size_train,
          epochs=epochs,
          validation_data=validation_generator,
          validation_steps=nb_validation_samples // batch_size_val)
model.save_weights('characterImg.h5')
print("Saved model characterImg.h5")
The above is my code. Can anyone help me understand what this error actually means? I've had a lot of trouble with it. Thanks! (If you need more information, please let me know.)
Okay, I'm not sure whether this will work for everyone, but to fix this I simply removed the line

steps_per_epoch=nb_train_samples // batch_size_train,

and it worked. I realize this isn't ideal, but for anyone desperately looking for a solution, it may do the trick.
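(Omitting steps_per_epoch works because Keras then falls back to len(train_generator), which flow_from_directory computes as ceil(found_samples / batch_size) from the images it actually found on disk. A rough sketch of that arithmetic; the 2500 figure is purely hypothetical, standing in for a directory that holds fewer images than the nb_train_samples constant assumes:)

```python
import math

# If all 2584 expected training images are present, len(train_generator)
# would match steps_per_epoch exactly:
nb_train_samples = 2584
batch_size_train = 76
print(math.ceil(nb_train_samples / batch_size_train))  # 34

# The warning fires when the directory actually holds fewer images
# (say 2500), so fewer batches exist than steps_per_epoch assumes:
print(math.ceil(2500 / batch_size_train))  # 33
```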
It looks like your dataset's length is smaller than your nb_train_samples / nb_validation_samples.

Add repeat() calls before fitting:

train_generator = train_generator.repeat()
validation_generator = validation_generator.repeat()

Note that repeat() is a tf.data.Dataset method, so this only applies if your inputs are tf.data datasets; a Keras flow_from_directory generator has no repeat() and would first need to be wrapped in a Dataset.
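(The repeat() call above assumes a tf.data.Dataset; a flow_from_directory generator has no repeat() method. A minimal sketch of wrapping a finite generator so that repeat() becomes available, using dummy numpy batches rather than the asker's actual data:)

```python
import numpy as np
import tensorflow as tf

# A finite generator standing in for an exhausted data source:
# it yields only 3 batches of (images, one-hot labels).
def finite_batches():
    for _ in range(3):
        yield (np.zeros((4, 200, 200, 3), np.float32),
               np.zeros((4, 8), np.float32))

# Wrap it in a tf.data.Dataset; repeat() then restarts the generator
# whenever it runs out, so fit() never sees it exhausted.
ds = tf.data.Dataset.from_generator(
    finite_batches,
    output_signature=(
        tf.TensorSpec(shape=(None, 200, 200, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 8), dtype=tf.float32),
    ),
).repeat()

# Take more batches than the raw generator holds to confirm it repeats.
taken = sum(1 for _ in ds.take(10))
print(taken)  # 10
```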