我通过
得到了一个数据集train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=validation_split,
subset="training",
seed=seed,
image_size=(img_height, img_width),
batch_size=batch_size)
(基于https://www.tensorflow.org/tutorials/load_data/images的代码,对配置进行了很小的更改)
我正在将最终模型转换为TFLite模型,该模型正在工作,但我认为该模型对于终端设备来说太大了,因此我试图通过提供representative_dataset
(如https://www.tensorflow.org/lite/performance/post_training_quantization)来运行训练后量化
然而,我不知道如何将image_dataset_from_directory
生成的数据集转换为representative_dataset
所期望的格式
提供的示例包含
def representative_dataset():
for data in tf.data.Dataset.from_tensor_slices((images)).batch(1).take(100):
yield [data.astype(tf.float32)]
我试过
def representative_dataset():
for data in train_ds.batch(1).take(100):
yield [data.astype(tf.float32)]
但这不是它
看起来像
def representative_dataset():
for image_batch, labels_batch in train_ds:
yield [image_batch]
是我正在寻找的,image_batch已经是tf.float32
我没能让tf.keras.preprocessing.image_dataset_from_directory
工作,但我有一些运气与tf.keras.preprocessing.ImageDataGenerator
。
在我的例子中,图像位于'images/all'目录中。我必须确保从该目录中删除所有非图像文件(例如XML注释)。
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.mobilenet import preprocess_input
def representative_dataset():
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
test_generator = test_datagen.flow_from_directory(
'./images',
target_size=(300, 300),
batch_size=1,
classes=['all'],
class_mode='categorical')
for ind in range(len(test_generator.filenames)):
img_with_label = test_generator.next()
yield [np.array(img_with_label[0], dtype=np.float32, ndmin=2)]