mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[quant] Implement PTQ for APoT FakeQuant (#81040)
### Summary: This PR implements PTQ for APoT FakeQuant. It runs models (Resnet-18 pre-trained model, ImageNet dataset) to compare accuracy metrics for different qconfig settings of uniform vs. APoT quantized activation and weight. According to the collected accuracy stats, model #2 (uniform activation and APoT weight) appears to have a slight improvement in accuracy compared to model #1 (uniform activation and uniform weight) for 8-bit and significant improvement for 4-bit (see "Accuracy Stats" section below). ### Test Plan: Run models with: `python test/quantization/core/experimental/fx_graph_mode_apot.py` ### Accuracy Stats: 8-bit (Uniform int8, APoT b = 8 k = 2) **Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 64.43% (Top-1), 85.62% (Top-5) **Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 64.51% (Top-1), 85.78% (Top-5) **Model #3:** APoT activation, APoT weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 64.32% (Top-1), 85.78% (Top-5) 4-bit (Uniform int4, APoT b = 4 k = 2) **Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 45.63% (Top-1), 71.96% (Top-5) **Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 64.24% (Top-1), 85.56% (Top-5) **Model #3:** APoT activation, APoT weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 45.40% (Top-1), 76.21% (Top-5) **Full Precision model (FX Graph Mode quantized)** Evaluation accuracy on test dataset: 69.76% (Top-1), 89.08% (Top-5) **Eager mode quantized model** Evaluation accuracy on test dataset: 69.49% (Top-1), 88.90% (Top-5) Pull Request resolved: https://github.com/pytorch/pytorch/pull/81040 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
f445c220be
commit
13ad4739a6
3
mypy.ini
3
mypy.ini
|
|
@ -73,6 +73,9 @@ ignore_missing_imports = True
|
|||
[mypy-torch.ao.quantization.experimental.fake_quantize_function]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-torch.ao.quantization.experimental.fake_quantize]
|
||||
ignore_missing_imports = True
|
||||
|
||||
#
|
||||
# Files with various errors. Mostly real errors, possibly some false
|
||||
# positives as well.
|
||||
|
|
|
|||
Binary file not shown.
257
test/quantization/core/experimental/fx_graph_mode_apot.py
Normal file
257
test/quantization/core/experimental/fx_graph_mode_apot.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
import torchvision.transforms.transforms as transforms
|
||||
import os
|
||||
import torch.quantization
|
||||
|
||||
# Setup warnings
|
||||
import warnings
|
||||
warnings.filterwarnings(
|
||||
action='ignore',
|
||||
category=DeprecationWarning,
|
||||
module=r'.*'
|
||||
)
|
||||
warnings.filterwarnings(
|
||||
action='default',
|
||||
module=r'torch.quantization'
|
||||
)
|
||||
|
||||
"""
|
||||
Define helper functions
|
||||
"""
|
||||
|
||||
# Specify random seed for repeatable results
|
||||
_ = torch.manual_seed(191009)
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self, name, fmt=':f'):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
||||
with torch.no_grad():
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def evaluate(model, criterion, data_loader):
|
||||
model.eval()
|
||||
top1 = AverageMeter('Acc@1', ':6.2f')
|
||||
top5 = AverageMeter('Acc@5', ':6.2f')
|
||||
cnt = 0
|
||||
with torch.no_grad():
|
||||
for image, target in data_loader:
|
||||
output = model(image)
|
||||
loss = criterion(output, target)
|
||||
cnt += 1
|
||||
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
||||
top1.update(acc1[0], image.size(0))
|
||||
top5.update(acc5[0], image.size(0))
|
||||
print('')
|
||||
|
||||
return top1, top5
|
||||
|
||||
def load_model(model_file):
|
||||
model = resnet18(pretrained=False)
|
||||
state_dict = torch.load(model_file)
|
||||
model.load_state_dict(state_dict)
|
||||
model.to("cpu")
|
||||
return model
|
||||
|
||||
def print_size_of_model(model):
|
||||
if isinstance(model, torch.jit.RecursiveScriptModule):
|
||||
torch.jit.save(model, "temp.p")
|
||||
else:
|
||||
torch.jit.save(torch.jit.script(model), "temp.p")
|
||||
print("Size (MB):", os.path.getsize("temp.p") / 1e6)
|
||||
os.remove("temp.p")
|
||||
|
||||
def prepare_data_loaders(data_path):
|
||||
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
dataset = torchvision.datasets.ImageNet(data_path,
|
||||
split="train",
|
||||
transform=transforms.Compose([transforms.RandomResizedCrop(224),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
normalize]))
|
||||
dataset_test = torchvision.datasets.ImageNet(data_path,
|
||||
split="val",
|
||||
transform=transforms.Compose([transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
normalize]))
|
||||
|
||||
train_sampler = torch.utils.data.RandomSampler(dataset)
|
||||
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
||||
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=train_batch_size,
|
||||
sampler=train_sampler)
|
||||
|
||||
data_loader_test = torch.utils.data.DataLoader(
|
||||
dataset_test, batch_size=eval_batch_size,
|
||||
sampler=test_sampler)
|
||||
|
||||
return data_loader, data_loader_test
|
||||
|
||||
data_path = '~/my_imagenet/'
|
||||
saved_model_dir = '/data/home/amandaliu/cluster/pytorch/test/quantization/core/experimental/data/'
|
||||
float_model_file = 'resnet18_pretrained_float.pth'
|
||||
|
||||
train_batch_size = 30
|
||||
eval_batch_size = 50
|
||||
|
||||
data_loader, data_loader_test = prepare_data_loaders(data_path)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
float_model = load_model(saved_model_dir + float_model_file).to("cpu")
|
||||
float_model.eval()
|
||||
|
||||
# deepcopy the model since we need to keep the original model around
|
||||
import copy
|
||||
model_to_quantize = copy.deepcopy(float_model)
|
||||
|
||||
model_to_quantize.eval()
|
||||
|
||||
"""
|
||||
Prepare models
|
||||
"""
|
||||
|
||||
# Note that this is temporary, we'll expose these functions to torch.quantization after official releasee
|
||||
from torch.quantization.quantize_fx import prepare_fx, convert_fx
|
||||
|
||||
def calibrate(model, data_loader):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for image, target in data_loader:
|
||||
model(image)
|
||||
|
||||
from torch.ao.quantization.experimental.qconfig import (
|
||||
uniform_qconfig_8bit,
|
||||
apot_weights_qconfig_8bit,
|
||||
apot_qconfig_8bit,
|
||||
uniform_qconfig_4bit,
|
||||
apot_weights_qconfig_4bit,
|
||||
apot_qconfig_4bit
|
||||
)
|
||||
|
||||
"""
|
||||
Prepare full precision model
|
||||
"""
|
||||
full_precision_model = float_model
|
||||
|
||||
top1, top5 = evaluate(full_precision_model, criterion, data_loader_test)
|
||||
print("Model #0 Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
|
||||
|
||||
"""
|
||||
Prepare model PTQ for specified qconfig for torch.nn.Linear
|
||||
"""
|
||||
def prepare_ptq_linear(qconfig):
|
||||
qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
|
||||
prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict) # fuse modules and insert observers
|
||||
calibrate(prepared_model, data_loader_test) # run calibration on sample data
|
||||
return prepared_model
|
||||
|
||||
"""
|
||||
Prepare model with uniform activation, uniform weight
|
||||
b=8, k=2
|
||||
"""
|
||||
|
||||
prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
|
||||
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
|
||||
|
||||
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
|
||||
print("Model #1 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
|
||||
|
||||
"""
|
||||
Prepare model with uniform activation, uniform weight
|
||||
b=4, k=2
|
||||
"""
|
||||
|
||||
prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
|
||||
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
|
||||
|
||||
top1, top5 = evaluate(quantized_model1, criterion, data_loader_test)
|
||||
print("Model #1 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
|
||||
|
||||
"""
|
||||
Prepare model with uniform activation, APoT weight
|
||||
(b=8, k=2)
|
||||
"""
|
||||
|
||||
prepared_model = prepare_ptq_linear(apot_weights_qconfig_8bit)
|
||||
|
||||
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
|
||||
print("Model #2 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
|
||||
|
||||
"""
|
||||
Prepare model with uniform activation, APoT weight
|
||||
(b=4, k=2)
|
||||
"""
|
||||
|
||||
prepared_model = prepare_ptq_linear(apot_weights_qconfig_4bit)
|
||||
|
||||
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
|
||||
print("Model #2 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
|
||||
|
||||
|
||||
"""
|
||||
Prepare model with APoT activation and weight
|
||||
(b=8, k=2)
|
||||
"""
|
||||
|
||||
prepared_model = prepare_ptq_linear(apot_qconfig_8bit)
|
||||
|
||||
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
|
||||
print("Model #3 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
|
||||
|
||||
"""
|
||||
Prepare model with APoT activation and weight
|
||||
(b=4, k=2)
|
||||
"""
|
||||
|
||||
prepared_model = prepare_ptq_linear(apot_qconfig_4bit)
|
||||
|
||||
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
|
||||
print("Model #3 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
|
||||
|
||||
"""
|
||||
Prepare eager mode quantized model
|
||||
"""
|
||||
|
||||
from torchvision.models.quantization.resnet import resnet18
|
||||
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
|
||||
top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
|
||||
print("Eager mode quantized model evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
|
||||
|
|
@ -35,7 +35,7 @@ class TestFakeQuantize(unittest.TestCase):
|
|||
|
||||
r""" Tests fake quantize forward() method
|
||||
by comparing result with expected
|
||||
float_to_reduced_precision mapping of input tensor.
|
||||
quant_dequant_APoT mapping of input tensor.
|
||||
Uses input tensor with random values from 0 -> 1000
|
||||
and APoT observer with hard-coded values b=4, k=2
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ r"""Converts floating point input into
|
|||
reduced precision floating point value
|
||||
based on quantization levels
|
||||
"""
|
||||
def float_to_reduced_precision(x, levels, indices):
|
||||
def quant_dequant_util(x, levels, indices):
|
||||
levels_lst = list(levels)
|
||||
indices_lst = list(indices)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,18 +10,23 @@ class APoTFakeQuantize(FakeQuantizeBase):
|
|||
quantization_levels: Tensor
|
||||
level_indices: Tensor
|
||||
|
||||
def __init__(self, **observer_kwargs):
|
||||
def __init__(self, observer=APoTObserver, **observer_kwargs):
|
||||
super().__init__()
|
||||
self.activation_post_process = APoTObserver(**observer_kwargs)
|
||||
self.activation_post_process = observer(**observer_kwargs)
|
||||
self.dtype = self.activation_post_process.dtype
|
||||
|
||||
def calculate_qparams(self, signed: bool): # type: ignore[override]
|
||||
def calculate_qparams(self, signed=False): # type: ignore[override]
|
||||
return self.activation_post_process.calculate_qparams(signed=signed)
|
||||
|
||||
def forward(self, X: torch.Tensor, signed: bool): # type: ignore[override]
|
||||
def forward(self, X: torch.Tensor): # type: ignore[override]
|
||||
if self.observer_enabled[0] == 1:
|
||||
self.activation_post_process.forward(X)
|
||||
self.alpha, self.gamma, self.quantization_levels, self.level_indices = \
|
||||
self.activation_post_process.calculate_qparams(signed)
|
||||
result = self.activation_post_process.calculate_qparams(signed=False)
|
||||
self.alpha = result[0]
|
||||
self.gamma = result[1]
|
||||
self.quantization_levels = result[2]
|
||||
self.level_indices = result[3]
|
||||
|
||||
if self.fake_quant_enabled[0] == 1:
|
||||
assert (self.alpha is not None
|
||||
and self.gamma is not None
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class APoTObserver(ObserverBase):
|
|||
self,
|
||||
b,
|
||||
k,
|
||||
dtype=torch.int32) -> None:
|
||||
dtype=torch.quint8) -> None:
|
||||
super().__init__(dtype)
|
||||
self.b = b
|
||||
self.k = k
|
||||
|
|
@ -47,7 +47,7 @@ class APoTObserver(ObserverBase):
|
|||
quantization_levels: non-uniform quantization levels (fp representation)
|
||||
level_indices: int representation of quantization_levels indices
|
||||
"""
|
||||
def _calculate_qparams(self, signed, min_val=None, max_val=None):
|
||||
def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
|
||||
if min_val is not None:
|
||||
self.min_val = min_val
|
||||
if max_val is not None:
|
||||
|
|
|
|||
46
torch/ao/quantization/experimental/qconfig.py
Normal file
46
torch/ao/quantization/experimental/qconfig.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import torch
|
||||
from torch.ao.quantization.qconfig import QConfig
|
||||
from torch.ao.quantization import MinMaxObserver
|
||||
from torch.ao.quantization.fake_quantize import FakeQuantize
|
||||
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize
|
||||
|
||||
"""
|
||||
Default symmetric fake_quant for activations.
|
||||
"""
|
||||
default_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
|
||||
qscheme=torch.per_tensor_symmetric,
|
||||
dtype=torch.quint8)
|
||||
|
||||
"""
|
||||
Default symmetric fake_quant for weights.
|
||||
"""
|
||||
default_weight_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
|
||||
qscheme=torch.per_tensor_symmetric,
|
||||
dtype=torch.qint8)
|
||||
|
||||
# uniform activation and weight, b=8 k=2
|
||||
uniform_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant,
|
||||
weight=default_weight_symmetric_fake_quant.with_args)
|
||||
|
||||
# uniform activation, APoT weight, b=8 k=2
|
||||
apot_weight_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant.with_args,
|
||||
weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))
|
||||
|
||||
# APoT activation and uniform weight, b=8 k=2
|
||||
apot_qconfig_8bit = QConfig(activation=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.quint8),
|
||||
weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))
|
||||
|
||||
# uniform activation and weight, b=4 k=2
|
||||
uniform_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
|
||||
quant_max=15),
|
||||
weight=default_weight_symmetric_fake_quant.with_args(quant_min=0,
|
||||
quant_max=15))
|
||||
|
||||
# uniform activation, APoT weight, b=4 k=2
|
||||
apot_weight_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
|
||||
quant_max=15),
|
||||
weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))
|
||||
|
||||
# APoT activation and uniform weight, b=4 k=2
|
||||
apot_qconfig_4bit = QConfig(activation=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.quint8),
|
||||
weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
import numpy as np
|
||||
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
|
||||
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float, quant_dequant_util
|
||||
|
||||
# class to store APoT quantizer and
|
||||
# implement quantize and dequantize
|
||||
|
|
@ -52,9 +52,9 @@ class APoTQuantizer():
|
|||
based on the calculated quantization levels from a specified APoT non-uniform observer.
|
||||
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
|
||||
Args:
|
||||
apot_tensor: quantized APoT Tensor to dequantize
|
||||
tensor2quantize: fp Tensor
|
||||
Returns:
|
||||
result: fp representation of input Tensor
|
||||
result: fp reduced precision representation of input Tensor
|
||||
"""
|
||||
def dequantize(self, apot_tensor) -> Tensor:
|
||||
orig_size = apot_tensor.data.size()
|
||||
|
|
@ -72,6 +72,21 @@ class APoTQuantizer():
|
|||
|
||||
return result
|
||||
|
||||
r""" Returns result of quantize -> dequantize on a fp Tensor (reduced precision)
|
||||
based on the calculated quantization levels from a specified APoT non-uniform observer.
|
||||
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
|
||||
Args:
|
||||
apot_tensor: quantized APoT Tensor to dequantize
|
||||
Returns:
|
||||
result: fp representation of input Tensor
|
||||
"""
|
||||
def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
|
||||
levels_lst = list(self.quantization_levels)
|
||||
|
||||
result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
|
||||
|
||||
return result
|
||||
|
||||
def q_apot_alpha(self) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -100,3 +115,22 @@ def dequantize_APoT(apot_tensor) -> Tensor:
|
|||
quantizer = apot_tensor.quantizer
|
||||
result = quantizer.dequantize(apot_tensor)
|
||||
return result
|
||||
|
||||
r""" Global method to create quantizer and call quantizer quant_dequant
|
||||
Args:
|
||||
tensor2quantize: fp Tensor to quantize
|
||||
alpha: Tensor qparam alpha (clipping level)
|
||||
gamma: Tensor qparam gamma (scale factor for quantization levels)
|
||||
quantization levels: Tensor with fp quantization levels
|
||||
level indices: Tensor with integer quantization level indices
|
||||
Returns:
|
||||
result: fp reduced precision Tensor from tensor2quantize
|
||||
"""
|
||||
def quant_dequant_APoT(tensor2quantize: Tensor,
|
||||
alpha: Tensor,
|
||||
gamma: Tensor,
|
||||
quantization_levels: Tensor,
|
||||
level_indices: Tensor) -> Tensor:
|
||||
quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
|
||||
result = quantizer.quant_dequant(tensor2quantize)
|
||||
return result
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user