来自 keras.preprocessing.image_dataset_from_directory 数据集的 TFLiteConverter representative_dataset

TFLiteConverter representative_dataset from keras.preprocessing.image_dataset_from_directory dataset

我有一个数据集来自

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=validation_split,
  subset="training",
  seed=seed,
  image_size=(img_height, img_width),
  batch_size=batch_size)

(基于 https://www.tensorflow.org/tutorials/load_data/images 中的代码,对配置进行了非常小的更改)

我正在将最终模型转换为工作正常的 TFLite 模型,但我认为该模型对于终端设备来说太大了,所以我正在尝试 运行 post 训练量化通过提供 representative_dataset (如 https://www.tensorflow.org/lite/performance/post_training_quantization

但是我无法弄清楚如何将 image_dataset_from_directory 生成的数据集转换为 representative_dataset

所期望的格式

提供的例子有

def representative_dataset():
  for data in tf.data.Dataset.from_tensor_slices((images)).batch(1).take(100):
    yield [data.astype(tf.float32)]

我试过

def representative_dataset():
  for data in train_ds.batch(1).take(100):
    yield [data.astype(tf.float32)]

但事实并非如此

看起来像

def representative_dataset():
  for image_batch, labels_batch in train_ds:
    yield [image_batch]

是我要找的,image_batch 已经是 tf.float32

我无法让 tf.keras.preprocessing.image_dataset_from_directory 开始工作,但我在 tf.keras.preprocessing.ImageDataGenerator 方面运气不错。

在我的例子中,图像位于 'images/all' 目录中。我必须确保从该目录中删除任何非图像文件(例如 XML 注释)。

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.mobilenet import preprocess_input

def representative_dataset():
  test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
  test_generator = test_datagen.flow_from_directory(
      './images', 
      target_size=(300, 300), 
      batch_size=1,
      classes=['all'],
      class_mode='categorical')
  for ind in range(len(test_generator.filenames)):
    img_with_label = test_generator.next()
    yield [np.array(img_with_label[0], dtype=np.float32, ndmin=2)]