[quant] Implement PTQ for APoT FakeQuant (#81040)
### Summary:
This PR implements PTQ for APoT FakeQuant. It runs models (ResNet-18 pre-trained model, ImageNet dataset) to compare accuracy metrics for different qconfig settings of uniform vs. APoT quantized activation and weight. According to the collected accuracy stats, model #2 (uniform activation and APoT weight) appears to have a slight improvement in accuracy compared to model #1 (uniform activation and uniform weight) for 8-bit and a significant improvement for 4-bit (see "Accuracy Stats" section below).

### Test Plan:
Run models with:
`python test/quantization/core/experimental/fx_graph_mode_apot.py`

### Accuracy Stats:

**8-bit (Uniform int8, APoT b = 8, k = 2)**

- **Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized). Evaluation accuracy on test dataset: 64.43% (Top-1), 85.62% (Top-5)
- **Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized). Evaluation accuracy on test dataset: 64.51% (Top-1), 85.78% (Top-5)
- **Model #3:** APoT activation, APoT weight (FX Graph Mode quantized). Evaluation accuracy on test dataset: 64.32% (Top-1), 85.78% (Top-5)

**4-bit (Uniform int4, APoT b = 4, k = 2)**

- **Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized). Evaluation accuracy on test dataset: 45.63% (Top-1), 71.96% (Top-5)
- **Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized). Evaluation accuracy on test dataset: 64.24% (Top-1), 85.56% (Top-5)
- **Model #3:** APoT activation, APoT weight (FX Graph Mode quantized). Evaluation accuracy on test dataset: 45.40% (Top-1), 76.21% (Top-5)

**Baselines**

- **Full precision model (FX Graph Mode):** Evaluation accuracy on test dataset: 69.76% (Top-1), 89.08% (Top-5)
- **Eager mode quantized model:** Evaluation accuracy on test dataset: 69.49% (Top-1), 88.90% (Top-5)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81040
Approved by: https://github.com/jerryzh168
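For orientation, the flow the new test script exercises boils down to the following sketch. Identifiers such as `float_model` and `data_loader_test` are assumed to be the pretrained ResNet-18 and the ImageNet loader set up in that script; the qconfig names come from the new `torch/ao/quantization/experimental/qconfig.py` added below. This is a compressed sketch, not the exact script.

```python
# Sketch of the PTQ comparison driven by test/quantization/core/experimental/fx_graph_mode_apot.py
import copy
import torch
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.experimental.qconfig import (
    uniform_qconfig_8bit,
    apot_weights_qconfig_8bit,
)

def prepare_ptq_linear(float_model, qconfig, data_loader):
    """Attach the given qconfig to every nn.Linear, then calibrate on sample data."""
    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
    prepared = prepare_fx(copy.deepcopy(float_model), qconfig_dict)  # fuse modules, insert observers
    with torch.no_grad():
        for image, _ in data_loader:
            prepared(image)  # calibration pass
    return prepared

# Model #1 (uniform activation/weight) is additionally lowered with convert_fx;
# the APoT variants (Models #2 and #3) are evaluated with fake-quant left in place.
# model1 = convert_fx(prepare_ptq_linear(float_model, uniform_qconfig_8bit, data_loader_test))
# model2 = prepare_ptq_linear(float_model, apot_weights_qconfig_8bit, data_loader_test)
```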
Commit 13ad4739a6 (parent f445c220be)
mypy.ini (3 lines changed)

```diff
@@ -73,6 +73,9 @@ ignore_missing_imports = True
 [mypy-torch.ao.quantization.experimental.fake_quantize_function]
 ignore_missing_imports = True
 
+[mypy-torch.ao.quantization.experimental.fake_quantize]
+ignore_missing_imports = True
+
 #
 # Files with various errors. Mostly real errors, possibly some false
 # positives as well.
```
Binary file not shown.
test/quantization/core/experimental/fx_graph_mode_apot.py (new file, 257 lines)

```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.transforms as transforms
import os
import torch.quantization

# resnet18 (float torchvision ResNet-18) is needed by load_model() below
from torchvision.models.resnet import resnet18

# Setup warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.quantization'
)

"""
Define helper functions
"""

# Specify random seed for repeatable results
_ = torch.manual_seed(191009)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
    print('')

    return top1, top5

def load_model(model_file):
    model = resnet18(pretrained=False)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to("cpu")
    return model

def print_size_of_model(model):
    if isinstance(model, torch.jit.RecursiveScriptModule):
        torch.jit.save(model, "temp.p")
    else:
        torch.jit.save(torch.jit.script(model), "temp.p")
    print("Size (MB):", os.path.getsize("temp.p") / 1e6)
    os.remove("temp.p")

def prepare_data_loaders(data_path):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = torchvision.datasets.ImageNet(data_path,
                                            split="train",
                                            transform=transforms.Compose([transforms.RandomResizedCrop(224),
                                                                          transforms.RandomHorizontalFlip(),
                                                                          transforms.ToTensor(),
                                                                          normalize]))
    dataset_test = torchvision.datasets.ImageNet(data_path,
                                                 split="val",
                                                 transform=transforms.Compose([transforms.Resize(256),
                                                                               transforms.CenterCrop(224),
                                                                               transforms.ToTensor(),
                                                                               normalize]))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

data_path = '~/my_imagenet/'
saved_model_dir = '/data/home/amandaliu/cluster/pytorch/test/quantization/core/experimental/data/'
float_model_file = 'resnet18_pretrained_float.pth'

train_batch_size = 30
eval_batch_size = 50

data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to("cpu")
float_model.eval()

# deepcopy the model since we need to keep the original model around
import copy
model_to_quantize = copy.deepcopy(float_model)

model_to_quantize.eval()

"""
Prepare models
"""

# Note that this is temporary, we'll expose these functions to torch.quantization after official release
from torch.quantization.quantize_fx import prepare_fx, convert_fx

def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

from torch.ao.quantization.experimental.qconfig import (
    uniform_qconfig_8bit,
    apot_weights_qconfig_8bit,
    apot_qconfig_8bit,
    uniform_qconfig_4bit,
    apot_weights_qconfig_4bit,
    apot_qconfig_4bit
)

"""
Prepare full precision model
"""
full_precision_model = float_model

top1, top5 = evaluate(full_precision_model, criterion, data_loader_test)
print("Model #0 Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model PTQ for specified qconfig for torch.nn.Linear
"""
def prepare_ptq_linear(qconfig):
    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
    prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict)  # fuse modules and insert observers
    calibrate(prepared_model, data_loader_test)  # run calibration on sample data
    return prepared_model

"""
Prepare model with uniform activation, uniform weight
b=8, k=2
"""

prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model

top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("Model #1 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model with uniform activation, uniform weight
b=4, k=2
"""

prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model

top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("Model #1 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model with uniform activation, APoT weight
(b=8, k=2)
"""

prepared_model = prepare_ptq_linear(apot_weights_qconfig_8bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #2 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model with uniform activation, APoT weight
(b=4, k=2)
"""

prepared_model = prepare_ptq_linear(apot_weights_qconfig_4bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #2 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model with APoT activation and weight
(b=8, k=2)
"""

prepared_model = prepare_ptq_linear(apot_qconfig_8bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #3 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model with APoT activation and weight
(b=4, k=2)
"""

prepared_model = prepare_ptq_linear(apot_qconfig_4bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #3 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare eager mode quantized model
"""

# rebinds resnet18 to the quantizable torchvision variant for the eager-mode baseline
from torchvision.models.quantization.resnet import resnet18
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
print("Eager mode quantized model evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
```
```diff
@@ -35,7 +35,7 @@ class TestFakeQuantize(unittest.TestCase):
     r""" Tests fake quantize forward() method
     by comparing result with expected
-    float_to_reduced_precision mapping of input tensor.
+    quant_dequant_APoT mapping of input tensor.
     Uses input tensor with random values from 0 -> 1000
     and APoT observer with hard-coded values b=4, k=2
     """
```
```diff
@@ -33,7 +33,7 @@ r"""Converts floating point input into
 reduced precision floating point value
 based on quantization levels
 """
-def float_to_reduced_precision(x, levels, indices):
+def quant_dequant_util(x, levels, indices):
     levels_lst = list(levels)
     indices_lst = list(indices)
```
```diff
@@ -10,18 +10,23 @@ class APoTFakeQuantize(FakeQuantizeBase):
     quantization_levels: Tensor
     level_indices: Tensor
 
-    def __init__(self, **observer_kwargs):
+    def __init__(self, observer=APoTObserver, **observer_kwargs):
         super().__init__()
-        self.activation_post_process = APoTObserver(**observer_kwargs)
+        self.activation_post_process = observer(**observer_kwargs)
+        self.dtype = self.activation_post_process.dtype
 
-    def calculate_qparams(self, signed: bool):  # type: ignore[override]
+    def calculate_qparams(self, signed=False):  # type: ignore[override]
         return self.activation_post_process.calculate_qparams(signed=signed)
 
-    def forward(self, X: torch.Tensor, signed: bool):  # type: ignore[override]
+    def forward(self, X: torch.Tensor):  # type: ignore[override]
         if self.observer_enabled[0] == 1:
             self.activation_post_process.forward(X)
-            self.alpha, self.gamma, self.quantization_levels, self.level_indices = \
-                self.activation_post_process.calculate_qparams(signed)
+            result = self.activation_post_process.calculate_qparams(signed=False)
+            self.alpha = result[0]
+            self.gamma = result[1]
+            self.quantization_levels = result[2]
+            self.level_indices = result[3]
 
         if self.fake_quant_enabled[0] == 1:
             assert (self.alpha is not None
                     and self.gamma is not None
```
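The reworked `forward(X)` above observes the input and caches the four APoT qparams on the module before fake-quantizing. A minimal sketch of driving the module directly, assuming the experimental package is importable and the defaults shown in this hunk (not part of this PR):

```python
# Sketch only: exercises the new forward(X) / calculate_qparams(signed=False)
# signatures from the hunk above; b, k, dtype mirror the qconfigs in this PR.
import torch
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize

fake_quant = APoTFakeQuantize(b=4, k=2, dtype=torch.quint8)  # kwargs are forwarded to APoTObserver

x = torch.rand(16, 8)
y = fake_quant(x)  # observer records stats, qparams are cached, then X is fake-quantized

print(fake_quant.alpha, fake_quant.gamma)       # clipping level and scale factor
print(fake_quant.quantization_levels.numel())   # non-uniform APoT levels
```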
```diff
@@ -23,7 +23,7 @@ class APoTObserver(ObserverBase):
         self,
         b,
         k,
-        dtype=torch.int32) -> None:
+        dtype=torch.quint8) -> None:
         super().__init__(dtype)
         self.b = b
         self.k = k
@@ -47,7 +47,7 @@ class APoTObserver(ObserverBase):
         quantization_levels: non-uniform quantization levels (fp representation)
         level_indices: int representation of quantization_levels indices
     """
-    def _calculate_qparams(self, signed, min_val=None, max_val=None):
+    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
         if min_val is not None:
             self.min_val = min_val
         if max_val is not None:
```
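A short sketch of obtaining the observer's qparams after these changes; the module path and the return order are assumptions taken from the `APoTFakeQuantize.forward` hunk above, not something this diff states explicitly:

```python
# Sketch only: observe a tensor, then pull the four APoT qparams in the order
# consumed by APoTFakeQuantize.forward (alpha, gamma, quantization_levels, level_indices).
import torch
from torch.ao.quantization.experimental.observer import APoTObserver

observer = APoTObserver(b=4, k=2)   # dtype now defaults to torch.quint8
observer(torch.rand(1000))          # record min/max statistics

alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False)
print(alpha, gamma)
print(quantization_levels.shape, level_indices.shape)
```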
torch/ao/quantization/experimental/qconfig.py (new file, 46 lines)

```python
import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize

"""
Default symmetric fake_quant for activations.
"""
default_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
                                                      qscheme=torch.per_tensor_symmetric,
                                                      dtype=torch.quint8)

"""
Default symmetric fake_quant for weights.
"""
default_weight_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
                                                             qscheme=torch.per_tensor_symmetric,
                                                             dtype=torch.qint8)

# uniform activation and weight, b=8 k=2
uniform_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant,
                               weight=default_weight_symmetric_fake_quant.with_args)

# uniform activation, APoT weight, b=8 k=2
apot_weights_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant.with_args,
                                    weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))

# APoT activation and weight, b=8 k=2
apot_qconfig_8bit = QConfig(activation=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.quint8),
                            weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))

# uniform activation and weight, b=4 k=2
uniform_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
                                                                                 quant_max=15),
                               weight=default_weight_symmetric_fake_quant.with_args(quant_min=0,
                                                                                    quant_max=15))

# uniform activation, APoT weight, b=4 k=2
apot_weights_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
                                                                                      quant_max=15),
                                    weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))

# APoT activation and weight, b=4 k=2
apot_qconfig_4bit = QConfig(activation=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.quint8),
                            weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))
```
```diff
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 import numpy as np
-from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
+from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float, quant_dequant_util
 
 # class to store APoT quantizer and
 # implement quantize and dequantize
@@ -52,9 +52,9 @@ class APoTQuantizer():
     based on the calculated quantization levels from a specified APoT non-uniform observer.
     The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
     Args:
-        apot_tensor: quantized APoT Tensor to dequantize
+        tensor2quantize: fp Tensor
     Returns:
-        result: fp representation of input Tensor
+        result: fp reduced precision representation of input Tensor
     """
     def dequantize(self, apot_tensor) -> Tensor:
         orig_size = apot_tensor.data.size()
@@ -72,6 +72,21 @@ class APoTQuantizer():
 
         return result
 
+    r""" Returns result of quantize -> dequantize on a fp Tensor (reduced precision)
+    based on the calculated quantization levels from a specified APoT non-uniform observer.
+    The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
+    Args:
+        apot_tensor: quantized APoT Tensor to dequantize
+    Returns:
+        result: fp representation of input Tensor
+    """
+    def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
+        levels_lst = list(self.quantization_levels)
+
+        result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
+
+        return result
+
     def q_apot_alpha(self) -> float:
         raise NotImplementedError
@@ -100,3 +115,22 @@ def dequantize_APoT(apot_tensor) -> Tensor:
     quantizer = apot_tensor.quantizer
     result = quantizer.dequantize(apot_tensor)
     return result
+
+r""" Global method to create quantizer and call quantizer quant_dequant
+    Args:
+        tensor2quantize: fp Tensor to quantize
+        alpha: Tensor qparam alpha (clipping level)
+        gamma: Tensor qparam gamma (scale factor for quantization levels)
+        quantization levels: Tensor with fp quantization levels
+        level indices: Tensor with integer quantization level indices
+    Returns:
+        result: fp reduced precision Tensor from tensor2quantize
+"""
+def quant_dequant_APoT(tensor2quantize: Tensor,
+                       alpha: Tensor,
+                       gamma: Tensor,
+                       quantization_levels: Tensor,
+                       level_indices: Tensor) -> Tensor:
+    quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
+    result = quantizer.quant_dequant(tensor2quantize)
+    return result
```
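To show how the new top-level `quant_dequant_APoT` entry point is meant to be driven, here is a hedged sketch. The module paths (`torch.ao.quantization.experimental.observer` and `.quantizer`) and the qparam ordering are assumptions carried over from the hunks above; treat this as intended usage rather than a verified run.

```python
# Sketch only: round-trip a fp tensor through APoT quantize -> dequantize
# using qparams produced by an APoTObserver.
import torch
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quant_dequant_APoT

t = torch.rand(1000)

observer = APoTObserver(b=4, k=2)
observer(t)
alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False)

# fp tensor snapped to the nearest APoT quantization level (reduced precision)
t_reduced = quant_dequant_APoT(t, alpha, gamma, quantization_levels, level_indices)
```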