无法在 TensorRT 中为 QAT 模型创建校准缓存

Cannot create the calibration cache for the QAT model in TensorRT

我使用 PyTorch 的量化感知训练(QAT)方法训练了一个量化模型,并希望创建校准缓存,以便通过 TensorRT 在 INT8 模式下进行推理。但在创建校准缓存时,我收到了以下警告,并且缓存没有被创建:

[03/06/2022-08:14:07] [TRT] [W] Calibrator won't be used in explicit precision mode. Use quantization aware training to generate network with Quantize/Dequantize nodes.
[03/06/2022-08:14:11] [TRT] [W] Some weights are outside of int8_t range and will be clipped to int8_t range.
[03/06/2022-08:14:11] [TRT] [W] Some weights are outside of int8_t range and will be clipped to int8_t range.
[03/06/2022-08:14:11] [TRT] [W] Some weights are outside of int8_t range and will be clipped to int8_t range.
[03/06/2022-08:14:11] [TRT] [W] Some weights are outside of int8_t range and will be clipped to int8_t range.

我已经对模型进行了相应的训练并转换为 ONNX:

import os
import sys
import argparse
import warnings
import collections

import torch
import torch.utils.data
from torch import nn

from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.hub import load_state_dict_from_url
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import quant_modules
import onnxruntime
import numpy as np
import models
import kornia
from prettytable import PrettyTable

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def get_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for the classification quantization flow.

    Exactly one of ``--ckpt-path``, ``--ckpt-url`` or ``--pretrained`` must be
    supplied (they form a required, mutually exclusive group), and
    ``--data-dir`` is mandatory.

    Help texts that mention a default use argparse's ``%(default)s``
    placeholder so the documented value can never drift from the real one.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='Classification quantization flow script')

    parser.add_argument('--data-dir', '-d', type=str, help='input data folder', required=True)
    parser.add_argument('--model-name', '-m', default='', help='model name, e.g. resnet50')
    parser.add_argument('--disable-pcq', '-dpcq', action="store_true", help='disable per-channel quantization for weights')
    parser.add_argument('--out-dir', '-o', default='/tmp', help='output folder: default /tmp')
    parser.add_argument('--print-freq', '-pf', type=int, default=20, help='evaluation print frequency: default 20')
    parser.add_argument('--threshold', '-t', type=float, default=-1.0, help='top1 accuracy threshold (less than 0.0 means no comparison): default -1.0')

    parser.add_argument('--batch-size-train', type=int, default=8, help='batch size for training: default %(default)s')
    parser.add_argument('--batch-size-test', type=int, default=8, help='batch size for testing: default %(default)s')
    parser.add_argument('--batch-size-onnx', type=int, default=20, help='batch size for onnx: default %(default)s')

    parser.add_argument('--seed', type=int, default=12345, help='random seed: default 12345')

    # One (and only one) source of initial weights has to be chosen.
    checkpoint = parser.add_mutually_exclusive_group(required=True)
    checkpoint.add_argument('--ckpt-path', default='', type=str, required=False,
                            help='path to latest checkpoint (default: none)')
    checkpoint.add_argument('--ckpt-url', default='', type=str, required=False,
                            help='url to latest checkpoint (default: none)')
    checkpoint.add_argument('--pretrained', action="store_true")

    parser.add_argument('--num-calib-batch', default=8, type=int,
                        help='Number of batches for calibration. 0 will disable calibration. (default: %(default)s)')
    parser.add_argument('--num-finetune-epochs', default=0, type=int,
                        help='Number of epochs to fine tune. 0 will disable fine tune. (default: 0)')
    parser.add_argument('--calibrator', type=str, choices=["max", "histogram"], default="max")
    parser.add_argument('--percentile', nargs='+', type=float, default=[99.9, 99.99, 99.999, 99.9999])
    parser.add_argument('--sensitivity', action="store_true", help="Build sensitivity profile")
    parser.add_argument('--evaluate-onnx', action="store_true", help="Evaluate exported ONNX")

    return parser

def prepare_model(
        model_name,
        num_class,
        data_dir,
        per_channel_quantization,
        batch_size_train,
        batch_size_test,
        batch_size_onnx,
        calibrator,
        pretrained,
        ckpt_path,
        ckpt_url=None):
    """
    Configure default quantizers, build the quantized model and create the data loaders.

    Args:
        model_name: Key looked up first in the project's ``models`` module and,
            if absent, in ``torchvision.models``.
        num_class: Output size of the replacement classification head.
        data_dir: Dataset root; ``train`` and ``test`` sub-folders are used.
        per_channel_quantization: When false, both input and weight quantizers
            are forced to per-tensor (``axis=None``).
        batch_size_train: Batch size of the training loader.
        batch_size_test: Batch size of the test loader.
        batch_size_onnx: Batch size of the loader used for ONNX evaluation.
        calibrator: Calibration method passed to ``QuantDescriptor``.
        pretrained: Forwarded to the model constructor; when false, weights are
            loaded from ``ckpt_path`` or ``ckpt_url`` instead.
        ckpt_path: Local checkpoint path (takes priority over ``ckpt_url``).
        ckpt_url: Checkpoint URL used when ``ckpt_path`` is empty.

    Returns:
        ``(model, data_loader_train, data_loader_test, data_loader_onnx)``.
    """
    ## Initialize quantization, model and data loaders
    if per_channel_quantization:
        print('<<<<<<< Per channel qaunt >>>>>>>>')
        quant_desc_input = QuantDescriptor(calib_method=calibrator)
        quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
    else:
        ## Force per tensor quantization for onnx runtime
        print('<<<<<<< Per tensor qaunt >>>>>>>>')
        quant_desc_input = QuantDescriptor(calib_method=calibrator, axis=None)
        quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

        quant_desc_weight = QuantDescriptor(calib_method=calibrator, axis=None)
        quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)
        quant_nn.QuantConvTranspose2d.set_default_quant_desc_weight(quant_desc_weight)
        quant_nn.QuantLinear.set_default_quant_desc_weight(quant_desc_weight)

    if model_name in models.__dict__:
        model = models.__dict__[model_name](pretrained=pretrained, quantize=True)
        num_feats = model.fc.in_features
        model.fc = nn.Linear(num_feats, num_class)

    else:
        print('Model is not available....downloading....')
        # Quantized layer substitution is only active while the torchvision
        # model (and its new head) is being constructed.
        quant_modules.initialize()
        model = torchvision.models.__dict__[model_name](pretrained=pretrained)
        if 'resnet' in model_name:
            num_feats = model.fc.in_features
            model.fc = nn.Linear(num_feats, num_class)
        if 'densenet' in model_name:
            num_feats = model.classifier.in_features
            model.classifier = nn.Linear(num_feats, num_class)
        quant_modules.deactivate()

    if not pretrained:
        # Keep the checkpoint in its own variable: the previous code rebound
        # `model` to the checkpoint dict and never loaded the weights, so the
        # following `model.eval()` was called on a dict.
        if ckpt_path:
            checkpoint = torch.load(ckpt_path)
        else:
            checkpoint = load_state_dict_from_url(ckpt_url)

        if isinstance(checkpoint, nn.Module):
            # Whole-model checkpoints (saved with `torch.save(model, ...)`)
            # already are a ready-to-use module.
            model = checkpoint
        else:
            # Unwrap the common container layouts before loading the weights.
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            elif 'model' in checkpoint:
                checkpoint = checkpoint['model']
            model.load_state_dict(checkpoint)
    model.eval()
    model.cuda()
    print(model)

    ## Prepare the data loaders
    traindir = os.path.join(data_dir, 'train')
    valdir = os.path.join(data_dir, 'test')
    # NOTE(review): `load_data` is not defined in the visible part of this file;
    # it is handed a lightweight stand-in for the argparse namespace it expects.
    _args = collections.namedtuple("mock_args", ["model", "distributed", "cache_dataset"])
    dataset, dataset_test, train_sampler, test_sampler = load_data(
        traindir, valdir, _args(model=model_name, distributed=False, cache_dataset=False))

    data_loader_train = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size_train,
        sampler=train_sampler, num_workers=4, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=batch_size_test,
        sampler=test_sampler, num_workers=4, pin_memory=True)

    data_loader_onnx = torch.utils.data.DataLoader(
        dataset_test, batch_size=batch_size_onnx,
        sampler=test_sampler, num_workers=4, pin_memory=True)

    return model, data_loader_train, data_loader_test, data_loader_onnx

def main(cmdline_args):
    """
    Run the full flow: calibrate, optionally fine-tune, export to ONNX and report accuracy.

    Args:
        cmdline_args: Argument list handed to the parser (e.g. ``sys.argv[1:]``).

    Returns:
        ``0`` on success; ``1`` when a non-negative ``--threshold`` is set and
        either the ONNX evaluation was requested but produced no result, or the
        fine-tuned accuracy falls short of ``top1_onnx - threshold``.
    """
    parser = get_parser()
    args = parser.parse_args(cmdline_args)
    print(parser.description)
    print(args)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # NOTE(review): the assignments below overwrite whatever was parsed from
    # the command line (and add `num_class`, which the parser does not define).
    args.model_name = 'resnet34'
    args.data_dir = '/home/Dataset/'
    args.disable_pcq = True # True disables per-channel quantization
    args.batch_size_train = 8
    args.batch_size_test = 8
    args.batch_size_onnx = 8
    args.calibrator = 'max'
    args.pretrained = True
    args.ckpt_path = ''
    args.ckpt_url = ''
    args.num_class = 5


    ## Prepare the pretrained model and data loaders
    model, data_loader_train, data_loader_test, data_loader_onnx = prepare_model(
        args.model_name,
        args.num_class,
        args.data_dir,
        not args.disable_pcq,
        args.batch_size_train,
        args.batch_size_test,
        args.batch_size_onnx,
        args.calibrator,
        args.pretrained,
        args.ckpt_path,
        args.ckpt_url)

    # Loss used for the sensitivity profile (gamma=2.0); it is replaced by a
    # gamma=3.0 instance before fine-tuning further down.
    kwargs = {"alpha": 0.75, "gamma": 2.0, "reduction": 'mean'}
    criterion = kornia.losses.FocalLoss(**kwargs)
    
    ## Calibrate the model
    with torch.no_grad():
        calibrate_model(
            model=model,
            model_name=args.model_name,
            data_loader=data_loader_train,
            num_calib_batch=args.num_calib_batch,
            calibrator=args.calibrator,
            hist_percentile=args.percentile,
            out_dir=args.out_dir)

    
    ## Build sensitivity profile
    if args.sensitivity:
        build_sensitivity_profile(model, criterion, data_loader_test)

  
    # Loss used for fine-tuning and all subsequent evaluations.
    kwargs = {"alpha": 0.75, "gamma": 3.0, "reduction": 'mean'}
    criterion = kornia.losses.FocalLoss(**kwargs)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_finetune_epochs)
    # NOTE(review): `train_one_epoch` and `evaluate` are not defined in the
    # visible part of this file.
    for epoch in range(args.num_finetune_epochs):
        # Train for a single epoch
        train_one_epoch(model, criterion, optimizer, data_loader_train, "cuda", epoch, 100)
        lr_scheduler.step()

    if args.num_finetune_epochs > 0:
        ## Evaluate after finetuning
        with torch.no_grad():
            print('Finetune evaluation:')
            top1_finetuned = evaluate(model, criterion, data_loader_test, device="cuda")
    else:
        # -1.0 marks "not measured" in the summary table.
        top1_finetuned = -1.0

    ## Export to ONNX
    onnx_filename = args.out_dir + '/' + args.model_name + ".onnx"
    top1_onnx = -1.0
    # The export always runs; the ONNX model is evaluated only when the export
    # succeeded and --evaluate-onnx was given.
    if export_onnx(model, onnx_filename, args.batch_size_onnx, not args.disable_pcq) and args.evaluate_onnx:
        ## Validate ONNX and evaluate
        top1_onnx = evaluate_onnx(onnx_filename, data_loader_onnx, criterion, args.print_freq)

    ## Print summary
    print("Accuracy summary:")
    table = PrettyTable(['Stage','Top1'])
    table.align['Stage'] = "l"
    table.add_row( [ 'Finetuned',   "{:.2f}".format(top1_finetuned) ] )
    table.add_row( [ 'ONNX',        "{:.2f}".format(top1_onnx) ] )
    print(table)

  
    ## Compare results (only when a non-negative threshold was requested)
    if args.threshold >= 0.0:
        if args.evaluate_onnx and top1_onnx < 0.0:
            print("Failed to export/evaluate ONNX!")
            return 1
        if args.num_finetune_epochs > 0:
            if top1_finetuned >= (top1_onnx - args.threshold):
                print("Accuracy threshold was met!")
            else:
                print("Accuracy threshold was missed!")
                return 1

    return 0

def evaluate_onnx(onnx_filename, data_loader, criterion, print_freq):
    """
    Evaluate an exported ONNX model with ONNX Runtime.

    Args:
        onnx_filename: Path of the ONNX file to load.
        data_loader: Iterable yielding ``(image, target)`` batches.
        criterion: Loss callable applied to ``(output, target)``.
        print_freq: Logging frequency forwarded to ``metric_logger.log_every``.

    Returns:
        The global average of the ``acc1`` meter.
    """
    print("Loading ONNX file: ", onnx_filename)
    ort_session = onnxruntime.InferenceSession(onnx_filename)
    # The graph's input name is fixed, so resolve it once instead of per batch.
    input_name = ort_session.get_inputs()[0].name

    # A single no_grad scope is enough; nesting a second one has no effect.
    with torch.no_grad():
        # NOTE(review): `utils` is not imported in the visible part of this
        # file -- confirm it is brought into scope before calling this function.
        metric_logger = utils.MetricLogger(delimiter="  ")
        header = 'Test:'
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to("cpu", non_blocking=True)
            input_data = np.array(image)

            # run the data through onnx runtime instead of torch model
            raw_result = ort_session.run([], {input_name: input_data})
            output = torch.tensor(raw_result[0])

            loss = criterion(output, target)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()

        print('  ONNXRuntime: Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
            .format(top1=metric_logger.acc1, top5=metric_logger.acc5))
        return metric_logger.acc1.global_avg

def export_onnx(model, onnx_filename, batch_onnx, per_channel_quantization):
    """
    Export the (fake-)quantized model to an ONNX file.

    Args:
        model: The model to export; it is switched to eval mode.
        onnx_filename: Destination path. If it contains ``resnet`` or
            ``densenet`` the output is named after that architecture's final
            layer; otherwise a generic ``output`` name is used.
        batch_onnx: Batch dimension of the dummy input used for tracing.
        per_channel_quantization: Selects opset 13 when true, opset 12 otherwise.

    Returns:
        ``True`` when the export succeeded, ``False`` when ``torch.onnx.export``
        raised ``ValueError``.
    """
    model.eval()
    quant_nn.TensorQuantizer.use_fb_fake_quant = True # We have to shift to pytorch's fake quant ops before exporting the model to ONNX

    opset_version = 13 if per_channel_quantization else 12

    # Export ONNX for multiple batch sizes
    print("Creating ONNX file: " + onnx_filename)
    dummy_input = torch.randn(batch_onnx, 3, 224, 224, device='cuda') #TODO: switch input dims by model
    input_names = ['input']

    # Fallback so that exporting any other architecture no longer fails with an
    # unbound `output_names`.
    output_names = ['output']
    if 'resnet' in onnx_filename:
        print('Changing last layer of resnet...')
        output_names = ['Linear[fc]']  ### ResNet34
    if 'densenet' in onnx_filename:
        print('Changing last layer of densenet...')
        output_names = ['Linear[classifier]'] #### DenseNet
    # Only the batch dimension of the input is left dynamic.
    dynamic_axes = {'input': {0: 'batch_size'}}

    try:
        torch.onnx.export(model, dummy_input, onnx_filename, input_names=input_names,
                          export_params=True, output_names=output_names, opset_version=opset_version,
                          dynamic_axes=dynamic_axes, verbose=True, enable_onnx_checker=False, do_constant_folding=True)

    except ValueError:
        warnings.warn(UserWarning("Per-channel quantization is not yet supported in Pytorch/ONNX RT (requires ONNX opset 13)"))
        print("Failed to export to ONNX")
        return False

    return True

def calibrate_model(model, model_name, data_loader, num_calib_batch, calibrator, hist_percentile, out_dir):
    """
    Collect calibration statistics, compute amax values and save the calibrated model(s).

    Nothing happens when ``num_calib_batch`` is not positive.

    Args:
        model: Model containing the quantizers to calibrate.
        model_name: Used as the prefix of the saved file names.
        data_loader: Source of calibration batches; its ``batch_size`` is used
            in the saved file names.
        num_calib_batch: Number of batches to feed for statistics collection.
        calibrator: ``"histogram"`` runs the percentile, mse and entropy
            methods; any other value runs max calibration.
        hist_percentile: Percentiles to evaluate in histogram mode.
        out_dir: Directory in which the whole-model ``.pth`` files are written.
    """
    if num_calib_batch <= 0:
        return

    print("Calibrating model")
    with torch.no_grad():
        collect_stats(model, data_loader, num_calib_batch)

    num_samples = num_calib_batch * data_loader.batch_size

    if calibrator != "histogram":
        compute_amax(model, method="max")
        calib_output = os.path.join(
            out_dir,
            F"{model_name}-max-{num_samples}.pth")
        # torch.save(model.state_dict(), calib_output) # Just weights
        torch.save(model, calib_output) # whole model
        return

    for percentile in hist_percentile:
        print(F"{percentile} percentile calibration")
        # The percentile has to be forwarded explicitly; otherwise every
        # iteration uses the calibrator's default and all saved files are the same.
        compute_amax(model, method="percentile", percentile=percentile)
        calib_output = os.path.join(
            out_dir,
            F"{model_name}-percentile-{percentile}-{num_samples}.pth")
        torch.save(model, calib_output) # whole model

    for method in ["mse", "entropy"]:
        print(F"{method} calibration")
        compute_amax(model, method=method)
        calib_output = os.path.join(
            out_dir,
            F"{model_name}-{method}-{num_samples}.pth")
        # torch.save(model.state_dict(), calib_output)
        torch.save(model, calib_output)

def collect_stats(model, data_loader, num_batches):
    """
    Feed ``num_batches`` batches through the model with calibration enabled.

    Quantizers that own a calibrator are switched from quantizing to
    calibrating for the duration of the pass; quantizers without one are
    disabled. The original state is restored afterwards.

    Args:
        model: Model whose ``TensorQuantizer`` modules are calibrated.
        data_loader: Iterable yielding ``(image, target)`` batches.
        num_batches: Exact number of batches to run.
    """
    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # Feed data to the network for collecting stats
    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        # Check the limit before the forward pass; checking afterwards ran one
        # batch more than requested.
        if i >= num_batches:
            break
        model(image.cuda())

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()

def compute_amax(model, **kwargs):
    """
    Load the calibrated amax into every ``TensorQuantizer`` of ``model``.

    Quantizers backed by a ``MaxCalibrator`` load their amax without extra
    arguments; all other calibrators receive ``**kwargs``. Every quantizer is
    printed, and the model is moved to the GPU at the end.
    """
    for name, module in model.named_modules():
        # Only quantizer modules are of interest.
        if not isinstance(module, quant_nn.TensorQuantizer):
            continue

        calibrator = module._calibrator
        if calibrator is not None:
            uses_max = isinstance(calibrator, calib.MaxCalibrator)
            extra = {} if uses_max else kwargs
            module.load_calib_amax(**extra)

        print(F"{name:40}: {module}")

    model.cuda()

def _quantized_layer_name(quantizer_name):
    """Return the name of the layer owning the input/weight quantizer ``quantizer_name``."""
    return quantizer_name.replace("._input_quantizer", "").replace("._weight_quantizer", "")


def build_sensitivity_profile(model, criterion, data_loader_test):
    """
    Evaluate the model with the quantizers of one layer enabled at a time.

    All modules whose name ends in ``_quantizer`` are disabled first. Then, for
    each owning layer in discovery order, that layer's quantizers are enabled,
    the model is evaluated, and the quantizers are disabled again.

    Quantizers are grouped by their exact owning layer. A plain substring match
    would, for example, treat ``layer1.0.conv1`` as part of ``conv1`` and enable
    several layers at once.

    Args:
        model: Model containing the quantizer modules.
        criterion: Loss passed through to ``evaluate``.
        data_loader_test: Data loader passed through to ``evaluate``.
    """
    # layer name -> [(quantizer name, quantizer module), ...]; dict order keeps
    # the layers in the order they were discovered.
    quantizers_by_layer = {}
    for name, module in model.named_modules():
        if name.endswith("_quantizer"):
            module.disable()
            quantizers_by_layer.setdefault(_quantized_layer_name(name), []).append((name, module))

    for quant_layer, quantizers in quantizers_by_layer.items():
        print("Enable", quant_layer)
        for name, module in quantizers:
            module.enable()
            print(F"{name:40}: {module}")
        with torch.no_grad():
            evaluate(model, criterion, data_loader_test, device="cuda")
        for name, module in quantizers:
            module.disable()
            print(F"{name:40}: {module}")

if __name__ == '__main__':
    # Use sys.exit rather than the bare `exit` helper, which is only injected
    # by the `site` module and is meant for interactive sessions.
    res = main(sys.argv[1:])
    sys.exit(res)

有关系统的更多信息:

TensorRT == 8.2
Pytorch == 1.9.0+cu111 
Torchvision == 0.10.0+cu111
ONNX == 1.9.0
ONNXRuntime == 1.8.1
pycuda == 2021

如果 ONNX 模型中已经包含 Q/DQ(量化/反量化)节点,您可能并不需要校准缓存,因为这些 Q/DQ 节点本身已经包含缩放系数(scale)和零点(zero point)等量化参数。您可以在 OnnxRuntime (>= v1.9.0) 的 TensorRT 执行提供程序中直接运行带有 Q/DQ 节点的 ONNX 模型。