(prototype) FX Graph Mode Post Training Static Quantization
Author: Jerry Zhang
This tutorial introduces the steps for post training static quantization in graph mode based on torch.fx.
The advantage of FX graph mode quantization is that quantization can be performed fully automatically on the model, although some effort may be required to make the model compatible with FX Graph Mode Quantization (that is, symbolically traceable with torch.fx). A separate tutorial will show how to make the part of the model we want to quantize compatible with FX Graph Mode Quantization.
We also have a tutorial for FX Graph Mode Post Training Dynamic Quantization.
TL;DR: The FX Graph Mode API looks like the following:
import torch
from torch.quantization import get_default_qconfig
# Note that this is temporary, we'll expose these functions to torch.quantization after official release
from torch.quantization.quantize_fx import prepare_fx, convert_fx
float_model.eval()
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
prepared_model = prepare_fx(float_model, qconfig_dict)  # fuse modules and insert observers
calibrate(prepared_model, data_loader_test)  # run calibration on sample data
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model
1. Motivation of FX Graph Mode Quantization
Currently PyTorch only has eager mode quantization: Static Quantization with Eager Mode in PyTorch.
That process involves multiple manual steps, including:
Explicitly quantizing and dequantizing activations. This is time consuming when floating point and quantized operations are mixed in a model.
Explicitly fusing modules. This requires manually identifying sequences of convolutions, batch norms, ReLUs and other fusion patterns.
Special handling for PyTorch tensor operations (like add, concat etc.).
Functionals do not have first class support (functional.conv2d and functional.linear would not get quantized).
Most of these required modifications come from the underlying limitations of eager mode quantization. Eager mode works at the module level because it cannot inspect the code that is actually run in the forward function. Quantization is achieved by module swapping, and since eager mode does not know how the modules are used in the forward function, users must insert QuantStub and DeQuantStub manually to mark the points where they want to quantize or dequantize. In graph mode, we can inspect the actual code that is executed in the forward function (e.g. aten function calls), and quantization is achieved by module and graph manipulations. Because graph mode has full visibility of the code that is run, the tool can automatically figure out which modules to fuse and where to insert observer calls, quantize/dequantize functions etc., so the whole quantization process can be automated.
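To make the "full visibility" point concrete, here is a minimal sketch of what symbolic tracing with torch.fx exposes. TinyBlock is a hypothetical toy module invented for this illustration; it is not part of the tutorial's model. The traced graph lists the module call, the functional call and the tensor addition as separate nodes, which is the kind of information a graph mode tool can use to decide where to fuse and where to insert observers.

```python
import operator

import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace


class TinyBlock(nn.Module):
    """Hypothetical toy module: a conv, a functional relu and a tensor add."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=1)

    def forward(self, x):
        y = F.relu(self.conv(x))
        return y + x


traced = symbolic_trace(TinyBlock())

# Each node records the kind of operation and what it calls.
node_summary = [(node.op, node.target) for node in traced.graph.nodes]
for op, target in node_summary:
    print(op, target)
```

Eager mode only sees the conv submodule here; the functional relu and the addition are visible only because the forward code itself was traced.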
Advantages of FX Graph Mode Quantization are:
Simple quantization flow, minimal manual steps
Unlocks the possibility of doing higher level optimizations like automatic precision selection
2. Define Helper Functions and Prepare Dataset
We’ll start by doing the necessary imports, defining some helper functions and preparing the data. These steps are identical to Static Quantization with Eager Mode in PyTorch.
To run the code in this tutorial using the entire ImageNet dataset, first download ImageNet by following the instructions at ImageNet Data. Unzip the downloaded file into the ‘data_path’ folder.
Download the torchvision resnet18 model and rename it to data/resnet18_pretrained_float.pth.
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization
# Setup warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.quantization'
)
# Specify random seed for repeatable results
_ = torch.manual_seed(191009)
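from torchvision.models.resnet import resnet18
from torch.quantization import get_default_qconfig
# FX graph mode entry points used in the rest of this tutorial
from torch.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx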
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def evaluate(model, criterion, data_loader):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
    print('')
    return top1, top5
def load_model(model_file):
    model = resnet18(pretrained=False)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to("cpu")
    return model
def print_size_of_model(model):
    if isinstance(model, torch.jit.RecursiveScriptModule):
        torch.jit.save(model, "temp.p")
    else:
        torch.jit.save(torch.jit.script(model), "temp.p")
    print("Size (MB):", os.path.getsize("temp.p")/1e6)
    os.remove("temp.p")
def prepare_data_loaders(data_path):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # transform must be passed by keyword: a positional argument
    # cannot follow the keyword argument split=...
    dataset = torchvision.datasets.ImageNet(
        data_path, split="train",
        transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    dataset_test = torchvision.datasets.ImageNet(
        data_path, split="val",
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)
    return data_loader, data_loader_test
data_path = '~/.data/imagenet'
saved_model_dir = 'data/'
float_model_file = 'resnet18_pretrained_float.pth'
train_batch_size = 30
eval_batch_size = 50
data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to("cpu")
float_model.eval()
# deepcopy the model since we need to keep the original model around
import copy
model_to_quantize = copy.deepcopy(float_model)
3. Set model to eval mode
For post training quantization, we’ll need to set the model to eval mode.
model_to_quantize.eval()
4. Specify how to quantize the model with qconfig_dict
qconfig_dict = {"" : default_qconfig}
We use the same qconfig used in eager mode quantization; qconfig is just a named tuple of the observers for activation and weight. qconfig_dict is a dictionary with the following configurations:
qconfig = {
    "" : qconfig_global,
    "sub" : qconfig_sub,
    "sub.fc" : qconfig_fc,
    "sub.conv": None
}
qconfig_dict = {
    # qconfig? means either a valid qconfig or None
    # optional, global config
    "": qconfig?,
    # optional, used for module and function types
    # could also be split into module_types and function_types if we prefer
    "object_type": [
        (torch.nn.Conv2d, qconfig?),
        (torch.add, qconfig?),
        ...,
    ],
    # optional, used for module names
    "module_name": [
        ("foo.bar", qconfig?),
        ...,
    ],
    # optional, matched in order, first match takes precedence
    "module_name_regex": [
        ("foo.*bar.*conv[0-9]+", qconfig?),
        ...,
    ],
    # priority (in increasing order): global, object_type, module_name_regex, module_name
    # qconfig == None means fusion and quantization should be skipped for anything
    # matching the rule
    # **api subject to change**
    # optional: specify the path for standalone modules
    # These modules are symbolically traced and quantized as one unit
    # so that the call to the submodule appears as one call_module
    # node in the forward graph of the GraphModule
    "standalone_module_name": [
        "submodule.standalone"
    ],
    "standalone_module_class": [
        StandaloneModuleClass
    ]
}
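The priority rule in the comments above (global, then object_type, then module_name_regex, then module_name) can be sketched in plain Python. The resolve_qconfig helper below is hypothetical and written only for illustration; it is not the function PyTorch uses internally. It assumes exact matching for module names and re.match for regexes, and it uses strings and empty classes as stand-ins for real qconfigs and module types.

```python
import re


def resolve_qconfig(qconfig_dict, module_name, module_type):
    """Hypothetical helper: later (more specific) rules override earlier ones."""
    # 1. Global default, stored under the empty-string key.
    result = qconfig_dict.get("", None)

    # 2. Rules keyed by module or function type.
    for obj_type, qconfig in qconfig_dict.get("object_type", []):
        if obj_type is module_type:
            result = qconfig
            break

    # 3. Regex rules: matched in order, first match takes precedence.
    for pattern, qconfig in qconfig_dict.get("module_name_regex", []):
        if re.match(pattern, module_name):
            result = qconfig
            break

    # 4. Exact module names have the highest priority.
    for name, qconfig in qconfig_dict.get("module_name", []):
        if name == module_name:
            result = qconfig
            break

    return result


class Conv2d:  # stand-in for torch.nn.Conv2d
    pass


class Linear:  # stand-in for torch.nn.Linear
    pass


example = {
    "": "global_qconfig",
    "object_type": [(Conv2d, "conv_qconfig")],
    "module_name_regex": [(r"layer[0-9]+\.conv.*", "regex_qconfig")],
    "module_name": [("layer1.conv1", None)],  # None: skip quantization
}

print(resolve_qconfig(example, "fc", Linear))
print(resolve_qconfig(example, "stem_conv", Conv2d))
print(resolve_qconfig(example, "layer2.conv1", Conv2d))
print(resolve_qconfig(example, "layer1.conv1", Conv2d))
```

In this sketch the last call returns None even though a type rule and a regex rule also match, because the module_name entry has the highest priority and None means the module is skipped.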
Utility functions related to qconfig can be found in the qconfig file.
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
5. Prepare the Model for Post Training Static Quantization
prepare_fx folds BatchNorm modules into the preceding Conv2d modules and inserts observers in appropriate places in the model.
prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
print(prepared_model.graph)
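The BatchNorm folding mentioned above relies on a standard identity: in eval mode a BatchNorm is a fixed affine transform, so it can be absorbed into the weight and bias of the preceding convolution. The sketch below checks that identity on scalars, treating the convolution as a single multiply-add. The function names fold_bn and conv_then_bn and all numeric values are made up for illustration; this is not the code prepare_fx runs.

```python
import math


def conv_then_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    """Reference: a scalar 'conv' (w * x + b) followed by eval-mode BatchNorm."""
    y = w * x + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return the folded weight and bias that reproduce conv followed by BatchNorm."""
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


# Made-up parameters for the illustration.
w, b = 0.8, 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0

w_folded, b_folded = fold_bn(w, b, gamma, beta, mean, var)

x = 2.0
reference = conv_then_bn(x, w, b, gamma, beta, mean, var)
folded = w_folded * x + b_folded
print(abs(reference - folded))
```

For a real Conv2d the same scale and shift are applied per output channel, which is why the fused module needs no separate BatchNorm at inference time.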
6. Calibration
The calibration function is run after the observers are inserted in the model. The purpose of calibration is to run some sample inputs that are representative of the workload (for example a sample of the training data set) through the model, so that the observers can record the statistics of the Tensors. This information is later used to calculate quantization parameters.
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
calibrate(prepared_model, data_loader_test) # run calibration on sample data
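To illustrate how observed statistics can turn into quantization parameters, here is a simplified pure-Python sketch of a min/max observer for unsigned 8-bit affine quantization. MinMaxSketch, quantize and dequantize are hypothetical names used only here, and the observers selected by the default qconfig may use more sophisticated statistics (for example histograms), so treat this as the basic idea rather than the actual implementation.

```python
class MinMaxSketch:
    """Hypothetical simplified observer: tracks the running min and max."""

    def __init__(self, qmin=0, qmax=255):
        self.qmin, self.qmax = qmin, qmax
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, values):
        # Called once per calibration batch.
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def calculate_qparams(self):
        # Extend the range to include zero so that 0.0 is exactly representable.
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = (hi - lo) / (self.qmax - self.qmin)
        if scale == 0.0:
            scale = 1.0  # degenerate case: only zeros were observed
        zero_point = self.qmin - round(lo / scale)
        zero_point = max(self.qmin, min(self.qmax, zero_point))
        return scale, zero_point


def quantize(x, scale, zero_point, qmin=0, qmax=255):
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale


observer = MinMaxSketch()
for batch in ([-1.0, 0.5], [0.25, 3.0]):  # made-up calibration batches
    observer.observe(batch)

scale, zero_point = observer.calculate_qparams()
print(scale, zero_point)
print(quantize(3.0, scale, zero_point), quantize(-1.0, scale, zero_point))
```

The more representative the calibration data, the closer the observed range is to the range seen in deployment, which is why calibration uses realistic inputs.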
7. Convert the Model to a Quantized Model
convert_fx takes a calibrated model and produces a quantized model.
quantized_model = convert_fx(prepared_model)
print(quantized_model)
8. Evaluation
We can now print the size and accuracy of the quantized model.
print("Size of model before quantization")
print_size_of_model(float_model)
print("Size of model after quantization")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("[before serialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
fx_graph_mode_model_file_path = saved_model_dir + "resnet18_fx_graph_mode_quantized.pth"
# this does not run due to some errors loading convrelu module:
# ModuleAttributeError: 'ConvReLU2d' object has no attribute '_modules'
# save the whole model directly
# torch.save(quantized_model, fx_graph_mode_model_file_path)
# loaded_quantized_model = torch.load(fx_graph_mode_model_file_path)
# save with state_dict
# torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path)
# import copy
# model_to_quantize = copy.deepcopy(float_model)
# prepared_model = prepare_fx(model_to_quantize, {"": qconfig})
# loaded_quantized_model = convert_fx(prepared_model)
# loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path))
# save with script
torch.jit.save(torch.jit.script(quantized_model), fx_graph_mode_model_file_path)
loaded_quantized_model = torch.jit.load(fx_graph_mode_model_file_path)
top1, top5 = evaluate(loaded_quantized_model, criterion, data_loader_test)
print("[after serialization/deserialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
If you want to get better accuracy or performance, try changing the qconfig_dict. We plan to add support for graph mode in the Numeric Suite so that you can easily determine the sensitivity of different modules in a model to quantization: PyTorch Numeric Suite Tutorial
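As one possible example of changing the qconfig_dict, the sketch below keeps the default qconfig globally but leaves a single module in floating point by mapping its name to None under "module_name". The choice of "fc" (the final linear layer of torchvision's resnet18) and the variable name qconfig_dict_skip_fc are assumptions made for illustration; whether this helps accuracy or performance has to be measured.

```python
from torch.quantization import get_default_qconfig

qconfig = get_default_qconfig("fbgemm")

# Quantize everything with the default qconfig, except the module named "fc",
# which is skipped because its qconfig is None.
qconfig_dict_skip_fc = {
    "": qconfig,
    "module_name": [("fc", None)],
}
```

To try it, this dictionary would be passed to prepare_fx on a fresh copy of the float model, followed by the same calibration and convert_fx steps as before.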
9. Debugging Quantized Model
We can also print the weights of the quantized and un-quantized conv to see the difference. We’ll first call fuse explicitly to fuse the conv and bn in the model.
Note that fuse_fx only works in eval mode.
fused = fuse_fx(float_model)
conv1_weight_after_fuse = fused.conv1[0].weight[0]
conv1_weight_after_quant = quantized_model.conv1.weight().dequantize()[0]
print(torch.max(abs(conv1_weight_after_fuse - conv1_weight_after_quant)))
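As a rough guide to how large that difference should be, the sketch below quantizes a handful of made-up weights to int8 with a single symmetric scale and checks that the round-trip error never exceeds half a quantization step. quantize_symmetric is a hypothetical helper, and the per-tensor scheme is a simplification: the qconfig used in this tutorial may quantize weights per channel, in which case the bound applies with each channel's own scale.

```python
def quantize_symmetric(weights, qmax=127):
    """Hypothetical per-tensor symmetric int8 quantization of a list of floats."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    q = [max(-qmax - 1, min(qmax, round(w / scale))) for w in weights]
    return q, scale


# Made-up weights for the illustration.
weights = [0.42, -0.17, 0.03, -0.5, 0.25]

q_weights, scale = quantize_symmetric(weights)
dequantized = [q * scale for q in q_weights]

max_err = max(abs(w - d) for w, d in zip(weights, dequantized))
print(max_err, scale / 2)
```

A maximum difference much larger than half the scale would suggest that something other than rounding, such as a mismatched fusion or a wrong scale, is contributing to the error.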
10. Comparison with Baseline Float Model and Eager Mode Quantization
scripted_float_model_file = "resnet18_scripted.pth"
print("Size of baseline model")
print_size_of_model(float_model)
top1, top5 = evaluate(float_model, criterion, data_loader_test)
print("Baseline Float Model Evaluation accuracy: %2.2f, %2.2f"%(top1.avg, top5.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)
In this section we compare the model quantized with FX graph mode quantization with the model quantized in eager mode. FX graph mode and eager mode produce very similar quantized models, so the expectation is that the accuracy and speedup are similar as well.
print("Size of FX graph mode quantized model")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("FX graph mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
from torchvision.models.quantization.resnet import resnet18
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
print("Size of eager mode quantized model")
eager_quantized_model = torch.jit.script(eager_quantized_model)
print_size_of_model(eager_quantized_model)
top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
print("eager mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
eager_mode_model_file = "resnet18_eager_mode_quantized.pth"
torch.jit.save(eager_quantized_model, saved_model_dir + eager_mode_model_file)
We can see that the model size and accuracy of the FX graph mode and eager mode quantized models are very similar.
Running the model in AIBench (with single threading) gives the following result:
Scripted Float Model:
Self CPU time total: 192.48ms
Scripted Eager Mode Quantized Model:
Self CPU time total: 50.76ms
Scripted FX Graph Mode Quantized Model:
Self CPU time total: 50.63ms
As we can see, for resnet18 both the FX graph mode and eager mode quantized models get a similar speedup over the floating point model: about 3.8x in the measurements above, within the range of around 2-4x. The actual speedup over the floating point model may vary depending on the model, device, build, input batch sizes, threading etc.