mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
### Summary: This PR implements PTQ for APoT FakeQuant. It runs models (Resnet-18 pre-trained model, ImageNet dataset) to compare accuracy metrics for different qconfig settings of uniform vs. APoT quantized activation and weight. According to the collected accuracy stats, model #2 (uniform activation and APoT weight) appears to have a slight improvement in accuracy compared to model #1 (uniform activation and uniform weight) for 8-bit and significant improvement for 4-bit (see "Accuracy Stats" section below). ### Test Plan: Run models with: `python test/quantization/core/experimental/fx_graph_mode_apot.py` ### Accuracy Stats: 8-bit (Uniform int8, APoT b = 8 k = 2) **Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 64.43% (Top-1), 85.62% (Top-5) **Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 64.51% (Top-1), 85.78% (Top-5) **Model #3:** APoT activation, APoT weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 64.32% (Top-1), 85.78% (Top-5) 4-bit (Uniform int4, APoT b = 4 k = 2) **Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 45.63% (Top-1), 71.96% (Top-5) **Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 64.24% (Top-1), 85.56% (Top-5) **Model #3:** APoT activation, APoT weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 45.40% (Top-1), 76.21% (Top-5) **Full Precision model (FX Graph Mode quantized)** Evaluation accuracy on test dataset: 69.76% (Top-1), 89.08% (Top-5) **Eager mode quantized model** Evaluation accuracy on test dataset: 69.49% (Top-1), 88.90% (Top-5) Pull Request resolved: https://github.com/pytorch/pytorch/pull/81040 Approved by: https://github.com/jerryzh168
148 lines
5.0 KiB
Python
148 lines
5.0 KiB
Python
"""
|
|
This module implements nonuniform observers used to collect statistics about
|
|
the values observed during calibration (PTQ) or training (QAT).
|
|
"""
|
|
|
|
import torch
|
|
import itertools
|
|
import matplotlib.pyplot as plt
|
|
from torch.ao.quantization.observer import ObserverBase
|
|
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
|
|
|
|
# TODO: Consider adding NonUniformQuantizationObserverBase class
# when more than one non-uniform method is implemented
|
|
|
|
class APoTObserver(ObserverBase):
    r"""Observer that tracks a running min/max and derives APoT quantization parameters.

    ``forward`` records the smallest and largest values seen so far, and
    ``calculate_qparams`` turns that range into the non-uniform quantization
    levels described in the APoT paper (https://arxiv.org/pdf/1909.13144.pdf).

    Attributes:
        b: APoT parameter; must be divisible by ``k``.
        k: APoT parameter; must be non-zero. Each additive term is drawn from
            a set built from ``2 ** k - 1`` powers of two plus zero.
        n: number of additive terms, ``b // k``; set by ``_calculate_qparams``.
        min_val: running minimum recorded by ``forward`` (empty until observed).
        max_val: running maximum recorded by ``forward`` (empty until observed).
    """
    b: int
    k: int
    n: int
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
            self,
            b,
            k,
            dtype=torch.quint8) -> None:
        super().__init__(dtype)
        self.b = b
        self.k = k

        # Empty tensors mean "nothing observed yet"; forward() checks numel()
        # before folding the stored values into the running min/max.
        self.min_val = torch.tensor([])
        self.max_val = torch.tensor([])

    def calculate_qparams(self, signed):
        r"""Calculates quantization parameters from the observed min/max.

        Args:
            signed: specifies whether to include signed values in quantization level calculations

        Returns:
            The ``(alpha, gamma, quantization_levels, level_indices)`` tuple
            produced by ``_calculate_qparams``.
        """
        return self._calculate_qparams(signed, self.min_val, self.max_val)

    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
        r""" Calculates nonuniform quantization parameters according to APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.

        Args:
            signed: specifies whether to include signed values in quantization level calculations
            min_val: optional arg that can override min_val internal attribute
            max_val: optional arg that can override max_val internal attribute

        Returns:
            alpha: ``max(-min_val, max_val)``, the maximum of the quantization range
            gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range
            quantization_levels: non-uniform quantization levels (fp representation), sorted ascending
            level_indices: int representation of quantization_levels indices

        Raises:
            AssertionError: if ``k`` is zero/unset or ``b`` is not divisible by ``k``.
        """
        # When given, the overrides replace the stored running statistics.
        if min_val is not None:
            self.min_val = min_val
        if max_val is not None:
            self.max_val = max_val

        # compute alpha
        alpha = torch.max(-self.min_val, self.max_val)

        # check for valid inputs of b, k
        assert self.k, "k must be a non-zero integer"
        assert self.b % self.k == 0, "b must be divisible by k"

        # compute n and store as member variable
        self.n = self.b // self.k

        # one tensor of candidate values per additive term
        p_all = []

        # create levels
        for i in range(self.n):
            p_curr = torch.tensor([0])

            for j in range(2 ** self.k - 1):
                curr_ele = 2 ** (-(i + j * self.n))
                p_curr = torch.cat((p_curr, torch.tensor([curr_ele])))
                # introduce signed numbers
                if signed:
                    p_curr = torch.cat((p_curr, torch.tensor([-curr_ele])))

            if signed:
                # sort in descending order so the largest value sits at index 0
                p_curr, _ = torch.sort(p_curr, descending=True)
            p_all.append(p_curr)

        # gamma calculation:
        # sum the largest element of every term tensor. After the descending
        # sort (signed) that element is at index 0; otherwise it is at index 1,
        # right after the leading zero. gamma is defined so that alpha is the
        # maximum of the range.
        largest_idx = 0 if signed else 1
        p_sum = sum((float(tens[largest_idx]) for tens in p_all), 0.0)

        # assign gamma
        gamma = alpha / p_sum

        # every quantization level is the sum of one element from each term
        # tensor, i.e. the row sums of the cartesian product
        level_sums = [sum(row, 0.0) for row in itertools.product(*p_all)]

        gamma_scalar = float(gamma)
        quantization_levels = torch.tensor([gamma_scalar * level for level in level_sums])
        quantization_levels, level_indices = quantization_levels.sort()

        return (alpha, gamma, quantization_levels, level_indices)

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``.

        Args:
            x_orig: tensor to observe; an empty tensor leaves the statistics untouched.

        Returns:
            ``x_orig``, unchanged.
        """
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()
        min_val, max_val = torch.aminmax(x)
        # Fold in previously recorded statistics, if any.
        if self.min_val.numel():
            min_val = torch.min(min_val, self.min_val)
        if self.max_val.numel():
            max_val = torch.max(max_val, self.max_val)
        self.min_val = min_val
        self.max_val = max_val
        return x_orig

    def quant_levels_visualization(self, obs_result, filename):
        r"""Plots the quantized value against the full-precision value on [0, 1).

        Args:
            obs_result: indexable whose elements 1 and 2 are forwarded to
                ``float_to_apot`` / ``apot_to_float``.
            filename: currently unused; the plot is only shown, not saved.
        """
        # NOTE(review): if obs_result is the tuple returned by calculate_qparams,
        # indices 1 and 2 are gamma and quantization_levels (index 0 is alpha).
        # Confirm this matches what float_to_apot / apot_to_float expect.
        xs = [float(x) / 1000.0 for x in range(1000)]
        ys = [apot_to_float(float_to_apot(x, obs_result[1], obs_result[2]),
                            obs_result[1], obs_result[2]).item() for x in xs]

        plt.figure(figsize=(15, 10))

        plt.plot(xs, ys)
        plt.title("APoT Quantization Plot")
        plt.xlabel("Full Precision")
        plt.ylabel("Quantized")
        plt.show()
|