pytorch/torch/ao/quantization/experimental/observer.py
asl3 13ad4739a6 [quant] Implement PTQ for APoT FakeQuant (#81040)
### Summary:
This PR implements PTQ for APoT FakeQuant. It runs a pre-trained ResNet-18 model on the ImageNet dataset to compare accuracy metrics across qconfig settings that combine uniform and APoT quantization for activations and weights.

According to the collected accuracy stats, model #2 (uniform activation and APoT weight) shows a slight accuracy improvement over model #1 (uniform activation and uniform weight) at 8-bit and a significant improvement at 4-bit (see the "Accuracy Stats" section below).

### Test Plan:
Run models with: `python test/quantization/core/experimental/fx_graph_mode_apot.py`
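
Model #1 (uniform activation, uniform weight) corresponds to the standard FX Graph Mode PTQ flow. A minimal sketch of that baseline is shown below; it is illustrative only: it swaps in a toy model and random calibration data for the pre-trained ResNet-18 / ImageNet setup used by the script above, and the exact `prepare_fx` / qconfig API may vary slightly across PyTorch versions.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# Toy stand-in for the pre-trained ResNet-18 used by the actual test script
float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3)).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Uniform activation + uniform weight qconfig (Model #1)
qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("fbgemm"))

prepared = prepare_fx(float_model, qconfig_mapping, example_inputs)  # insert observers
with torch.no_grad():
    prepared(torch.randn(8, 3, 32, 32))  # calibration pass (ImageNet batches in the real script)
quantized = convert_fx(prepared)  # baseline model evaluated for Top-1 / Top-5 accuracy
```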

### Accuracy Stats:
8-bit (Uniform int8, APoT b = 8, k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.43% (Top-1), 85.62% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.51% (Top-1), 85.78% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.32% (Top-1), 85.78% (Top-5)

4-bit (Uniform int4, APoT b = 4, k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.63% (Top-1), 71.96% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.24% (Top-1), 85.56% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.40% (Top-1), 76.21% (Top-5)

**Full Precision model**
Evaluation accuracy on test dataset: 69.76% (Top-1), 89.08% (Top-5)

**Eager mode quantized model**
Evaluation accuracy on test dataset: 69.49% (Top-1), 88.90% (Top-5)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81040
Approved by: https://github.com/jerryzh168
2022-07-28 07:21:31 +00:00


"""
This module implements nonuniform observers used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
"""
import torch
import itertools
import matplotlib.pyplot as plt
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
# TODO: Consider adding NonUniformQuantizationObserverBase class
# when more than one non-uniform method is implemented
class APoTObserver(ObserverBase):
    b: int
    k: int
    n: int
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
            self,
            b,
            k,
            dtype=torch.quint8) -> None:
        super().__init__(dtype)
        self.b = b
        self.k = k

        self.min_val = torch.tensor([])
        self.max_val = torch.tensor([])

    def calculate_qparams(self, signed):
        return self._calculate_qparams(signed, self.min_val, self.max_val)

    # min_val and max_val are optional args that override
    # the min_val and max_val observed by forward
    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
        r"""Calculates nonuniform quantization parameters according to the APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.

        Args:
            signed: specifies whether to include signed values in quantization level calculations
            min_val: optional arg that can override the min_val internal attribute
            max_val: optional arg that can override the max_val internal attribute

        Returns:
            alpha: clipping threshold, computed as max(-min_val, max_val)
            gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range
            quantization_levels: non-uniform quantization levels (fp representation)
            level_indices: int representation of quantization_levels indices
        """
        if min_val is not None:
            self.min_val = min_val
        if max_val is not None:
            self.max_val = max_val

        # compute alpha
        alpha = torch.max(-self.min_val, self.max_val)

        # check for valid inputs of b, k
        assert(self.k and self.k != 0)
        assert(self.b % self.k == 0)

        # compute n and store as member variable
        self.n = self.b // self.k

        # store a list of subtensors (all levels)
        p_all = []

        # create levels
        for i in range(0, self.n):
            p_curr = torch.tensor([0])

            for j in range(0, (2 ** self.k - 2) + 1):
                curr_ele = 2 ** (- (i + j * self.n))
                p_append = torch.tensor([curr_ele])
                p_curr = torch.cat((p_curr, p_append))

                # introduce signed numbers
                if signed:
                    p_curr = torch.cat((p_curr, torch.tensor([-curr_ele])))

            if signed:
                # sort tensor in reverse order before adding to list if signed
                sorted_levels, _ = torch.sort(p_curr, descending=True)
                p_all.append(sorted_levels)
            else:
                p_all.append(p_curr)

        # gamma calculation:
        # loop through all level tensors and add the largest level of each
        # (index 0 if signed, index 1 otherwise);
        # gamma is defined to ensure alpha is the maximum of the range
        p_sum = 0.0
        for tens in p_all:
            if signed:
                p_sum += float(tens[0])
            else:
                p_sum += float(tens[1])

        # assign gamma
        gamma = alpha / p_sum

        # calculate cartesian product
        cartesian_product = list(itertools.product(*p_all))

        quantization_levels_list = []

        # calculate sum of each row
        for row in cartesian_product:
            row_sum = 0.0
            for ele in row:
                row_sum += ele
            quantization_levels_list.append(row_sum)

        quantization_levels_gamma = [float(gamma) * ele for ele in quantization_levels_list]
        quantization_levels = torch.tensor(quantization_levels_gamma)
        level_indices = torch.tensor([])
        quantization_levels, level_indices = quantization_levels.sort()

        return (alpha, gamma, quantization_levels, level_indices)
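
    # Worked example for intuition, using the 4-bit configuration from the PR
    # description (b=4, k=2, unsigned): n = b // k = 2, so two level sets are built:
    #     p_all[0] = [0, 2**0, 2**-2, 2**-4] = [0, 1, 0.25, 0.0625]
    #     p_all[1] = [0, 2**-1, 2**-3, 2**-5] = [0, 0.5, 0.125, 0.03125]
    # The largest entry of each set is 2**-i, so p_sum = 1 + 0.5 = 1.5 and
    # gamma = alpha / 1.5. The cartesian product yields 4 * 4 = 16 = 2**b sums,
    # and the largest sum (1.5) maps exactly to alpha after scaling by gamma.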
r"""Records the running minimum and maximum of ``x``."""
def forward(self, x_orig):
if x_orig.numel() == 0:
return x_orig
x = x_orig.detach()
min_val, max_val = torch.aminmax(x)
if self.min_val.numel():
min_val = torch.min(min_val, self.min_val)
if self.max_val.numel():
max_val = torch.max(max_val, self.max_val)
self.min_val = min_val
self.max_val = max_val
return x_orig

    def quant_levels_visualization(self, obs_result, filename):
        xs = [float(x) / 1000.0 for x in range(1000)]
        ys = [apot_to_float(float_to_apot(x, obs_result[1], obs_result[2]),
                            obs_result[1], obs_result[2]).item() for x in xs]

        f = plt.figure(figsize=(15, 10))

        plt.plot(xs, ys)
        plt.title("APoT Quantization Plot")
        plt.xlabel("Full Precision")
        plt.ylabel("Quantized")
        plt.savefig(filename)  # write the plot to the requested file
        plt.show()
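
To tie the pieces together, here is a minimal usage sketch of `APoTObserver` (the calibration tensor and the `b`, `k`, and `signed` choices are arbitrary placeholders):

```python
import torch
from torch.ao.quantization.experimental.observer import APoTObserver

observer = APoTObserver(b=8, k=2)
calibration_batch = torch.randn(1000).clamp(-1, 1)  # stand-in calibration data
observer(calibration_batch)  # forward() updates the running min_val / max_val

# Derive the APoT qparams from the observed range
alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False)
print(alpha, gamma, quantization_levels.numel())  # 2 ** b = 256 levels in the unsigned case
```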