我使用Keras imagedataggenerator .flow_from_directory(…)创建训练和测试数据集。然后我想用这些数据拟合model.fit()
。在Tensorflow 2.1中,它工作得非常好。但是,在Tensorflow 2.2中运行相同的代码会生成:TypeError: data type not understood
。您建议如何克服这个问题并运行TF2.2?
代码示例:
train_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255., dtype=tf.float32)
train_data = train_gen.flow_from_directory(directory=os.path.join(current_dir, data, 'train/'), target_size=(width, height), class_mode='sparse')
...
model.fit(train_data, epochs=50) # This generates an error in TF2.2, but in TF2.1 works fine.
在TF2.2中产生此错误的另一种方法是遍历生成器:
train_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255., dtype=tf.float32)
train_data = train_gen.flow_from_directory(directory=os.path.join(current_dir, data, 'train/'), target_size=(width, height), class_mode='sparse')
for x,y in train_data:
print(type(x), type(y))
问题出在keras版本。以下配置导致错误。
keras 2.3.1
keras-preprocessing 1.1.2
更改到这个版本后一切正常:
keras 2.4.3
keras-preprocessing 1.1.0