TFLiteConverter representative_dataset from keras.preprocess



我通过

得到了一个数据集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=validation_split,
subset="training",
seed=seed,
image_size=(img_height, img_width),
batch_size=batch_size)

(基于https://www.tensorflow.org/tutorials/load_data/images的代码,对配置进行了很小的更改)

我正在将最终模型转换为TFLite模型,该模型正在工作,但我认为该模型对于终端设备来说太大了,因此我试图通过提供representative_dataset(如https://www.tensorflow.org/lite/performance/post_training_quantization)来运行训练后量化

然而,我不知道如何将image_dataset_from_directory生成的数据集转换为representative_dataset所期望的格式

提供的示例包含

def representative_dataset():
for data in tf.data.Dataset.from_tensor_slices((images)).batch(1).take(100):
yield [data.astype(tf.float32)]

我试过

def representative_dataset():
for data in train_ds.batch(1).take(100):
yield [data.astype(tf.float32)]

但这不是它

看起来像

def representative_dataset():
for image_batch, labels_batch in train_ds:
yield [image_batch]

是我正在寻找的,image_batch已经是tf.float32

我没能让tf.keras.preprocessing.image_dataset_from_directory工作,但我有一些运气与tf.keras.preprocessing.ImageDataGenerator

在我的例子中,图像位于'images/all'目录中。我必须确保从该目录中删除所有非图像文件(例如XML注释)。

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.mobilenet import preprocess_input
def representative_dataset():
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
test_generator = test_datagen.flow_from_directory(
'./images', 
target_size=(300, 300), 
batch_size=1,
classes=['all'],
class_mode='categorical')
for ind in range(len(test_generator.filenames)):
img_with_label = test_generator.next()
yield [np.array(img_with_label[0], dtype=np.float32, ndmin=2)]

相关内容

  • 没有找到相关文章

最新更新