<keras.preprocessing.image.DirectoryIterator> 对象抛出 TypeError: unsupported operand type(s) for +: 'int' and 'str'



我使用下面的代码片段，创建用作训练和验证生成器的 <keras.preprocessing.image.DirectoryIterator> 对象。

# Builds the Keras directory-based image generators for the CelebA data.
# NOTE(review): as posted, this snippet has lost its indentation and appears
# truncated -- see the note at the end of load_data.
class DataLoader:
# data_config: config object; .data.IMG_HEIGHT, .data.IMG_WIDTH and
# .train.BATCH_SIZE are read below.
# prefix: selects the directory 'data/celeba-dataset/<prefix>-train'.
@staticmethod
def load_data(data_config, prefix = "blond"):

# Training images: Inception-v3 input preprocessing plus random augmentation
# (rotation, width/height shift, shear, zoom, horizontal flip).
train_datagen =  tf.keras.preprocessing.image.ImageDataGenerator(
preprocessing_function=tf.keras.applications.inception_v3.preprocess_input,
rotation_range=30,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True
)
# NOTE(review): validation uses the same random augmentation as training;
# confirm this is intended (it likely makes validation metrics noisier).
valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
preprocessing_function=tf.keras.applications.inception_v3.preprocess_input,
rotation_range=30,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True
)
# BUG: in CFG, IMG_HEIGHT, IMG_WIDTH and BATCH_SIZE are strings ("218", "178",
# "64"). They are passed through unconverted. The str batch_size later fails
# in DirectoryIterator.__len__ with "unsupported operand type(s) for +: 'int'
# and 'str'". Cast them with int(...) here (or store ints in CFG).
train_generator = train_datagen.flow_from_directory(
'data/celeba-dataset/{}-train'.format(prefix),
target_size=(data_config.data.IMG_HEIGHT, data_config.data.IMG_WIDTH),
batch_size=data_config.train.BATCH_SIZE)
# NOTE(review): the snippet seems to end early. Model.load_data unpacks
# (train_generator, validation_generator) from this call, but no validation
# generator is created from valid_datagen and nothing is returned here.

然后我创建了另一个 Model 类，其中包含下面的 load_data 和 train 函数：

class Model(BaseModel):
    """InceptionV3-based classifier trained from directory image generators.

    Call order used by the training script: ``load_data()``, ``build()``
    (not shown here; presumably it creates ``self.model``), then ``train()``.
    """

    def __init__(self, config):
        super().__init__(config)
        # CFG stores its numbers as strings (e.g. "218", "64"), so every
        # numeric setting is cast to int exactly once, here.
        self.img_height = int(self.config.data.IMG_HEIGHT)
        self.img_width = int(self.config.data.IMG_WIDTH)
        self.base_model = tf.keras.applications.InceptionV3(
            weights='imagenet',
            include_top=False,
            input_shape=(self.img_height, self.img_width, 3),
        )
        self.model = None  # expected to be set by build() (not shown)
        self.training_samples = int(self.config.data.TRAINING_SAMPLES)
        self.batch_size = int(self.config.train.BATCH_SIZE)
        self.steps_per_epoch = self.training_samples // self.batch_size
        self.num_epochs = int(self.config.train.EPOCHS)
        self.train_generator = None
        self.validation_generator = None

    def load_data(self):
        """Create the training and validation generators from ``self.config``."""
        # load_data is a staticmethod, so no DataLoader instance is needed.
        self.train_generator, self.validation_generator = DataLoader.load_data(self.config)

    def train(self, class_weight=None):
        """Compile ``self.model`` and fit it on the loaded generators.

        Args:
            class_weight: optional ``{class_index: weight}`` dict forwarded to
                ``fit``. TF 2.x Keras accepts only a dict or None here; the
                former ``'auto'`` string makes ``fit`` fail.

        Returns:
            Tuple ``(history['loss'], history['val_loss'])`` from ``fit``.

        Raises:
            RuntimeError: if ``build()`` or ``load_data()`` has not been run.
        """
        if self.model is None:
            raise RuntimeError("build() must be called before train()")
        if self.train_generator is None:
            raise RuntimeError("load_data() must be called before train()")

        self.model.compile(
            # ``lr`` is a deprecated alias of ``learning_rate`` in TF 2.x.
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.0001, momentum=0.9),
            loss='categorical_crossentropy',
            metrics=[tf.keras.metrics.TopKCategoricalAccuracy(k=1)],
        )
        checkpointer = tf.keras.callbacks.ModelCheckpoint(
            filepath='weights.best.inc.blond.hdf5',
            verbose=1,
            save_best_only=True,
        )
        model_history = self.model.fit(
            self.train_generator,
            validation_data=self.validation_generator,
            steps_per_epoch=self.steps_per_epoch,
            epochs=self.num_epochs,
            class_weight=class_weight,
            callbacks=[checkpointer],
        )
        return model_history.history['loss'], model_history.history['val_loss']

在运行以下代码时,

# Training entry point. Order matters: train() compiles self.model and fits it
# on the generators set by load_data(); build() (not shown) is presumably what
# creates self.model.
model = Model(CFG)
model.load_data()
model.build()
# train() returns (train_losses, val_losses); the result is discarded here.
model.train()

得到以下Traceback

Traceback (most recent call last):
File "/Users/sauravmaheshkar/github/compression/train.py", line 14, in <module>
model.train()
File "/Users/sauravmaheshkar/github/compression/model/ForgetModel.py", line 70, in train
callbacks=[checkpointer, ],
File "/Users/sauravmaheshkar/opt/anaconda3/envs/compression/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1064, in fit
steps_per_execution=self._steps_per_execution)
File "/Users/sauravmaheshkar/opt/anaconda3/envs/compression/lib/python3.6/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1112, in __init__
model=model)
File "/Users/sauravmaheshkar/opt/anaconda3/envs/compression/lib/python3.6/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 898, in __init__
self._size = len(x)
File "/Users/sauravmaheshkar/opt/anaconda3/envs/compression/lib/python3.6/site-packages/keras_preprocessing/image/iterator.py", line 68, in __len__
return (self.n + self.batch_size - 1) // self.batch_size  # round up
TypeError: unsupported operand type(s) for +: 'int' and 'str'

Config.py文件

"""Project Config in JSON"""
# NOTE: every numeric value in this config is a string (e.g. "64"), not an
# int. Consumers must cast with int(...) before passing values to Keras; an
# uncast string BATCH_SIZE is what produced the TypeError in
# DirectoryIterator.__len__ ("unsupported operand type(s) for +: 'int' and 'str'").
# NOTE(review): code elsewhere reads this via attribute access
# (config.data.IMG_HEIGHT), so presumably the dict is wrapped in an
# attribute-style object before use -- confirm.
CFG = {
"data": {
"data_folder" : "data/CelebA/",
"images_folder" : "data/CelebA/img_align_celeba/img_align_celeba/",
"IMG_HEIGHT": "218",
"IMG_WIDTH": "178",
"TRAINING_SAMPLES": "10000",
"VALIDATION_SAMPLES": "2000",
"TEST_SAMPLES": "2000",
},
"train": {
"BATCH_SIZE": "64",
"EPOCHS": "10",
}   
}

包版本
  • tensorflow==2.4.1
  • Keras-Preprocessing==1.1.2

预期输出：TensorFlow 训练循环正常运行。

我认为问题出在 DataLoader 类的 load_data 函数中，具体来说是这一行：

batch_size=data_config.train.BATCH_SIZE

从你的配置文件可以看出，BATCH_SIZE 的值是字符串 "64" 而不是整数，所以 Keras 在计算 (self.n + self.batch_size - 1) 时会把 int 和 str 相加而报错。只需在这一行加上 int() 转换，问题就能解决。因此，将这一行替换为：

batch_size=int(data_config.train.BATCH_SIZE)

我建议你也检查一下其他参数，例如 IMG_HEIGHT 和 IMG_WIDTH 同样是字符串，传给 target_size 之前也应转换为 int。

或者，直接删除配置文件（Config.py）中这些数值两边的引号，把它们写成整数，也可以解决问题。

如你所见,数字/整型不需要引号。

相关内容

  • 没有找到相关文章

最新更新