TensorFlow 2 Quantization Aware Training (QAT) with tf.GradientTape

Can anyone point me to references where I can learn how to perform Quantization Aware Training (QAT) with tf.GradientTape on TensorFlow 2?

I have only seen this done with the tf.keras API. I don't use tf.keras; I always build custom training loops with tf.GradientTape to have more control over the training process. I now need to quantize a model, but I only see references on how to do it with the tf.keras API.

In the official example here, they demonstrate QAT training with model.fit. What you asked for is a demonstration of Quantization Aware Training with tf.GradientTape(), but for a complete reference let's show both here.


Base model training. This is taken straight from the official doc; check there for details.

import os
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
10ms/step - loss: 0.5411 - accuracy: 0.8507 - val_loss: 0.1142 - val_accuracy: 0.9705
<tensorflow.python.keras.callbacks.History at 0x7f9ee970ab90>

QAT with .fit.

Now, perform QAT on the base model.

# ------------- Quantization Aware Training -------------
import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model
# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

q_aware_model.summary()
train_images_subset = train_images[0:1000] 
train_labels_subset = train_labels[0:1000]
q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)


356ms/step - loss: 0.1431 - accuracy: 0.9629 - val_loss: 0.1626 - val_accuracy: 0.9500
<tensorflow.python.keras.callbacks.History at 0x7f9edf0aef90>

Checking performance

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)

Baseline test accuracy: 0.9660999774932861
Quant test accuracy: 0.9660000205039978

QAT with tf.GradientTape().

Here is the QAT training part done with a custom loop. Note that we could also perform custom training on the base model in the same way.

batch_size = 500

train_dataset = tf.data.Dataset.from_tensor_slices((train_images_subset,
                                                     train_labels_subset))
train_dataset = train_dataset.batch(batch_size=batch_size, 
                                    drop_remainder=False)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

for epoch in range(1):
    for x, y in train_dataset:
        with tf.GradientTape() as tape:
            preds = q_aware_model(x, training=True)
            loss = loss_fn(y, preds)
        grads = tape.gradient(loss, q_aware_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, q_aware_model.trainable_variables))
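The bare loop above only updates the weights; to also get per-epoch loss and accuracy logging comparable to what model.fit prints, you can add tf.keras.metrics objects to the loop. A self-contained sketch below uses a tiny stand-in model and random data (hypothetical shapes, for illustration only); in the loop above you would use q_aware_model and train_dataset instead:

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in model and random data (hypothetical shapes, illustration only);
# in the answer above you would use q_aware_model and train_dataset instead.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
])
x = np.random.rand(64, 4).astype("float32")
y = np.random.randint(0, 3, size=(64,)).astype("int32")
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
epoch_loss = tf.keras.metrics.Mean()
epoch_acc = tf.keras.metrics.SparseCategoricalAccuracy()

for epoch in range(2):
    epoch_loss.reset_state()
    epoch_acc.reset_state()
    for xb, yb in dataset:
        with tf.GradientTape() as tape:
            preds = model(xb, training=True)
            loss = loss_fn(yb, preds)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        epoch_loss.update_state(loss)      # running mean of batch losses
        epoch_acc.update_state(yb, preds)  # running accuracy over the epoch
    print(f"epoch {epoch}: loss={epoch_loss.result():.4f} "
          f"acc={epoch_acc.result():.4f}")
```

Calling reset_state() at the top of each epoch is what makes the printed numbers per-epoch averages rather than running totals across all epochs.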
        
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9660999774932861
Quant test accuracy: 0.9645000100135803
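Finally, whichever training path you use, the usual next step in the official QAT guide is converting the trained model to a quantized TFLite model with tf.lite.TFLiteConverter. A minimal sketch below converts a tiny stand-in Keras model (hypothetical architecture, for self-containment); in the example above you would pass q_aware_model instead:

```python
import tensorflow as tf

# Tiny stand-in Keras model (hypothetical architecture, illustration only);
# in the answer above you would convert q_aware_model instead.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2),
])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()  # serialized flatbuffer (bytes)

with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```

With a QAT model, Optimize.DEFAULT uses the fake-quantization ranges learned during training to produce the quantized weights, which is the whole point of running QAT in the first place.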