[quant] Implement PTQ for APoT FakeQuant (#81040)

### Summary:
This PR implements PTQ for APoT FakeQuant. It runs a pre-trained ResNet-18 on the ImageNet dataset and compares accuracy metrics across qconfig settings that mix uniform and APoT fake quantization for activations and weights.

According to the collected accuracy stats, Model #2 (uniform activation, APoT weight) shows a slight accuracy improvement over Model #1 (uniform activation, uniform weight) at 8 bits and a significant improvement at 4 bits (see the "Accuracy Stats" section below).

### Test Plan:
Run models with: `python test/quantization/core/experimental/fx_graph_mode_apot.py`
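
For reference, the core PTQ flow the script exercises is roughly the following (condensed from the new script included in this diff; `float_model`, `calibrate`, `evaluate`, `criterion`, and `data_loader_test` are defined there):

```python
import copy
import torch
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.experimental.qconfig import (
    uniform_qconfig_8bit,       # uniform activation + weight
    apot_weights_qconfig_8bit,  # uniform activation, APoT weight
)

# Quantize only the Linear layers with the chosen qconfig.
qconfig_dict = {"object_type": [(torch.nn.Linear, uniform_qconfig_8bit)]}

prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict)  # fuse modules, insert observers
calibrate(prepared_model, data_loader_test)                            # run calibration data through the model
quantized_model = convert_fx(prepared_model)                           # lower to a quantized model
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)

# For the APoT qconfigs (Models #2 and #3), the script evaluates the prepared
# (fake-quantized) model directly instead of calling convert_fx.
```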

### Accuracy Stats:
**8-bit** (uniform: int8; APoT: b = 8, k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.43% (Top-1), 85.62% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.51% (Top-1), 85.78% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.32% (Top-1), 85.78% (Top-5)

**4-bit** (uniform: int4; APoT: b = 4, k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.63% (Top-1), 71.96% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.24% (Top-1), 85.56% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.40% (Top-1), 76.21% (Top-5)

**Full-precision model (unquantized baseline)**
Evaluation accuracy on test dataset: 69.76% (Top-1), 89.08% (Top-5)

**Eager mode quantized model**
Evaluation accuracy on test dataset: 69.49% (Top-1), 88.90% (Top-5)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81040
Approved by: https://github.com/jerryzh168
Authored by asl3 on 2022-07-27 20:15:31 -07:00; committed by PyTorch MergeBot
parent f445c220be
commit 13ad4739a6
9 changed files with 358 additions and 13 deletions

View File

@@ -73,6 +73,9 @@ ignore_missing_imports = True
 [mypy-torch.ao.quantization.experimental.fake_quantize_function]
 ignore_missing_imports = True
 
+[mypy-torch.ao.quantization.experimental.fake_quantize]
+ignore_missing_imports = True
+
 #
 # Files with various errors. Mostly real errors, possibly some false
 # positives as well.

View File

@@ -0,0 +1,257 @@
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.transforms as transforms
import os
import torch.quantization
# resnet18 is used by load_model() below; the quantized variant imported near the
# end of this file (for the eager mode baseline) shadows this name afterwards.
from torchvision.models.resnet import resnet18
# Setup warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.quantization'
)
"""
Define helper functions
"""
# Specify random seed for repeatable results
_ = torch.manual_seed(191009)
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def evaluate(model, criterion, data_loader):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
    print('')
    return top1, top5

def load_model(model_file):
    model = resnet18(pretrained=False)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to("cpu")
    return model

def print_size_of_model(model):
    if isinstance(model, torch.jit.RecursiveScriptModule):
        torch.jit.save(model, "temp.p")
    else:
        torch.jit.save(torch.jit.script(model), "temp.p")
    print("Size (MB):", os.path.getsize("temp.p") / 1e6)
    os.remove("temp.p")
def prepare_data_loaders(data_path):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = torchvision.datasets.ImageNet(data_path,
                                            split="train",
                                            transform=transforms.Compose([transforms.RandomResizedCrop(224),
                                                                          transforms.RandomHorizontalFlip(),
                                                                          transforms.ToTensor(),
                                                                          normalize]))
    dataset_test = torchvision.datasets.ImageNet(data_path,
                                                 split="val",
                                                 transform=transforms.Compose([transforms.Resize(256),
                                                                               transforms.CenterCrop(224),
                                                                               transforms.ToTensor(),
                                                                               normalize]))
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test
data_path = '~/my_imagenet/'
saved_model_dir = '/data/home/amandaliu/cluster/pytorch/test/quantization/core/experimental/data/'
float_model_file = 'resnet18_pretrained_float.pth'
train_batch_size = 30
eval_batch_size = 50
data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to("cpu")
float_model.eval()
# deepcopy the model since we need to keep the original model around
import copy
model_to_quantize = copy.deepcopy(float_model)
model_to_quantize.eval()
"""
Prepare models
"""
# Note that this is temporary; we'll expose these functions to torch.quantization after the official release
from torch.quantization.quantize_fx import prepare_fx, convert_fx
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
from torch.ao.quantization.experimental.qconfig import (
    uniform_qconfig_8bit,
    apot_weights_qconfig_8bit,
    apot_qconfig_8bit,
    uniform_qconfig_4bit,
    apot_weights_qconfig_4bit,
    apot_qconfig_4bit
)
"""
Prepare full precision model
"""
full_precision_model = float_model
top1, top5 = evaluate(full_precision_model, criterion, data_loader_test)
print("Model #0 Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
"""
Prepare model PTQ for specified qconfig for torch.nn.Linear
"""
def prepare_ptq_linear(qconfig):
    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
    prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict)  # fuse modules and insert observers
    calibrate(prepared_model, data_loader_test)  # run calibration on sample data
    return prepared_model
"""
Prepare model with uniform activation, uniform weight
b=8, k=2
"""
prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("Model #1 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
"""
Prepare model with uniform activation, uniform weight
b=4, k=2
"""
prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("Model #1 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
"""
Prepare model with uniform activation, APoT weight
(b=8, k=2)
"""
prepared_model = prepare_ptq_linear(apot_weights_qconfig_8bit)
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #2 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
"""
Prepare model with uniform activation, APoT weight
(b=4, k=2)
"""
prepared_model = prepare_ptq_linear(apot_weights_qconfig_4bit)
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #2 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
"""
Prepare model with APoT activation and weight
(b=8, k=2)
"""
prepared_model = prepare_ptq_linear(apot_qconfig_8bit)
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #3 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
"""
Prepare model with APoT activation and weight
(b=4, k=2)
"""
prepared_model = prepare_ptq_linear(apot_qconfig_4bit)
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #3 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
"""
Prepare eager mode quantized model
"""
from torchvision.models.quantization.resnet import resnet18
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
print("Eager mode quantized model evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))

View File

@@ -35,7 +35,7 @@ class TestFakeQuantize(unittest.TestCase):
     r""" Tests fake quantize forward() method
     by comparing result with expected
-    float_to_reduced_precision mapping of input tensor.
+    quant_dequant_APoT mapping of input tensor.
     Uses input tensor with random values from 0 -> 1000
     and APoT observer with hard-coded values b=4, k=2
     """

View File

@@ -33,7 +33,7 @@ r"""Converts floating point input into
 reduced precision floating point value
 based on quantization levels
 """
-def float_to_reduced_precision(x, levels, indices):
+def quant_dequant_util(x, levels, indices):
     levels_lst = list(levels)
     indices_lst = list(indices)

View File

@@ -10,18 +10,23 @@ class APoTFakeQuantize(FakeQuantizeBase):
     quantization_levels: Tensor
     level_indices: Tensor

-    def __init__(self, **observer_kwargs):
+    def __init__(self, observer=APoTObserver, **observer_kwargs):
         super().__init__()
-        self.activation_post_process = APoTObserver(**observer_kwargs)
+        self.activation_post_process = observer(**observer_kwargs)
+        self.dtype = self.activation_post_process.dtype

-    def calculate_qparams(self, signed: bool):  # type: ignore[override]
+    def calculate_qparams(self, signed=False):  # type: ignore[override]
         return self.activation_post_process.calculate_qparams(signed=signed)

-    def forward(self, X: torch.Tensor, signed: bool):  # type: ignore[override]
+    def forward(self, X: torch.Tensor):  # type: ignore[override]
         if self.observer_enabled[0] == 1:
             self.activation_post_process.forward(X)
-            self.alpha, self.gamma, self.quantization_levels, self.level_indices = \
-                self.activation_post_process.calculate_qparams(signed)
+            result = self.activation_post_process.calculate_qparams(signed=False)
+            self.alpha = result[0]
+            self.gamma = result[1]
+            self.quantization_levels = result[2]
+            self.level_indices = result[3]

         if self.fake_quant_enabled[0] == 1:
             assert (self.alpha is not None
                     and self.gamma is not None
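
For reference, a minimal usage sketch of the reworked `APoTFakeQuantize` (illustrative only, not part of this diff; it assumes the extra keyword arguments are forwarded to the default `APoTObserver` as shown in the constructor above):

```python
import torch
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize

fq = APoTFakeQuantize(b=4, k=2, dtype=torch.quint8)  # observer_kwargs go to APoTObserver
x = torch.rand(1000)
out = fq(x)  # observer pass derives alpha/gamma/levels, then fake-quant maps x onto APoT levels
```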

View File

@@ -23,7 +23,7 @@ class APoTObserver(ObserverBase):
         self,
         b,
         k,
-        dtype=torch.int32) -> None:
+        dtype=torch.quint8) -> None:
         super().__init__(dtype)
         self.b = b
         self.k = k
@@ -47,7 +47,7 @@ class APoTObserver(ObserverBase):
     quantization_levels: non-uniform quantization levels (fp representation)
     level_indices: int representation of quantization_levels indices
     """
-    def _calculate_qparams(self, signed, min_val=None, max_val=None):
+    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
         if min_val is not None:
             self.min_val = min_val
         if max_val is not None:

View File

@@ -0,0 +1,46 @@
import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize
"""
Default symmetric fake_quant for activations.
"""
default_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
qscheme=torch.per_tensor_symmetric,
dtype=torch.quint8)
"""
Default symmetric fake_quant for weights.
"""
default_weight_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
qscheme=torch.per_tensor_symmetric,
dtype=torch.qint8)
# uniform activation and weight, b=8 k=2
uniform_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant,
weight=default_weight_symmetric_fake_quant.with_args)
# uniform activation, APoT weight, b=8 k=2
apot_weight_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant.with_args,
weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))
# APoT activation and uniform weight, b=8 k=2
apot_qconfig_8bit = QConfig(activation=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.quint8),
weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))
# uniform activation and weight, b=4 k=2
uniform_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
quant_max=15),
weight=default_weight_symmetric_fake_quant.with_args(quant_min=0,
quant_max=15))
# uniform activation, APoT weight, b=4 k=2
apot_weight_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
quant_max=15),
weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))
# APoT activation and uniform weight, b=4 k=2
apot_qconfig_4bit = QConfig(activation=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.quint8),
weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))

View File

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 import numpy as np
-from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
+from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float, quant_dequant_util

 # class to store APoT quantizer and
 # implement quantize and dequantize
@@ -52,9 +52,9 @@ class APoTQuantizer():
     based on the calculated quantization levels from a specified APoT non-uniform observer.
     The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
     Args:
-        apot_tensor: quantized APoT Tensor to dequantize
+        tensor2quantize: fp Tensor
     Returns:
-        result: fp representation of input Tensor
+        result: fp reduced precision representation of input Tensor
     """
     def dequantize(self, apot_tensor) -> Tensor:
         orig_size = apot_tensor.data.size()
@@ -72,6 +72,21 @@ class APoTQuantizer():
         return result

+    r""" Returns result of quantize -> dequantize on a fp Tensor (reduced precision)
+    based on the calculated quantization levels from a specified APoT non-uniform observer.
+    The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
+    Args:
+        apot_tensor: quantized APoT Tensor to dequantize
+    Returns:
+        result: fp representation of input Tensor
+    """
+    def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
+        levels_lst = list(self.quantization_levels)
+        result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
+        return result
+
     def q_apot_alpha(self) -> float:
         raise NotImplementedError
@@ -100,3 +115,22 @@ def dequantize_APoT(apot_tensor) -> Tensor:
     quantizer = apot_tensor.quantizer
     result = quantizer.dequantize(apot_tensor)
     return result
+
+r""" Global method to create quantizer and call quantizer quant_dequant
+    Args:
+        tensor2quantize: fp Tensor to quantize
+        alpha: Tensor qparam alpha (clipping level)
+        gamma: Tensor qparam gamma (scale factor for quantization levels)
+        quantization levels: Tensor with fp quantization levels
+        level indices: Tensor with integer quantization level indices
+    Returns:
+        result: fp reduced precision Tensor from tensor2quantize
+"""
+def quant_dequant_APoT(tensor2quantize: Tensor,
+                       alpha: Tensor,
+                       gamma: Tensor,
+                       quantization_levels: Tensor,
+                       level_indices: Tensor) -> Tensor:
+    quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
+    result = quantizer.quant_dequant(tensor2quantize)
+    return result
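
For context, a hypothetical end-to-end sketch of driving the new `quant_dequant_APoT` helper (not part of this diff; it assumes `APoTObserver` lives at `torch.ao.quantization.experimental.observer` and exposes the `calculate_qparams(signed=...)` API used in the fake_quantize diff above):

```python
import torch
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quant_dequant_APoT

# Collect statistics with an APoT observer (b=4, k=2), then derive qparams.
obs = APoTObserver(b=4, k=2)
x = torch.rand(1000)
obs(x)
alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False)

# Map the tensor onto the nearest APoT levels and back to fp (quantize -> dequantize).
# Note: quant_dequant() uses Tensor.apply_, which modifies its input in place; clone first.
x_dq = quant_dequant_APoT(x.clone(), alpha, gamma, quantization_levels, level_indices)
```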