### Summary
Refactors the quantization levels visualization function to pass the alpha qparam in the `float_to_apot` function call (following the `float_to_apot` signature update). Also adds additional detail to the documentation for `quant_levels_visualization`.

### Test Plan
Print the visualization by calling the `quant_levels_visualization` function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82790
Approved by: https://github.com/jerryzh168
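A minimal sketch of the test-plan call, assuming illustrative `b`/`k` values and random calibration data (neither is specified in the PR; the module path follows this package's layout):

```python
import torch
from torch.ao.quantization.experimental.observer import APoTObserver

observer = APoTObserver(b=4, k=2)       # illustrative bitwidth/term choices
observer(torch.rand(1000))              # forward() records running min/max
observer.quant_levels_visualization(signed=False)  # displays the plot
```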
"""
|
|
This module implements nonuniform observers used to collect statistics about
|
|
the values observed during calibration (PTQ) or training (QAT).
|
|
"""
|
|
|
|
import torch
|
|
import itertools
|
|
import matplotlib.pyplot as plt
|
|
from torch.ao.quantization.observer import ObserverBase
|
|
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
|
|
|
|
# TODO: Consider adding NonUniformQuantizationObserverBase class
|
|
# when more than one non-uniform method is implemented
|
|
|
|
class APoTObserver(ObserverBase):
    b: int
    k: int
    n: int
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
            self,
            b,
            k,
            dtype=torch.quint8) -> None:
        super().__init__(dtype)
        self.b = b
        self.k = k

        self.min_val = torch.tensor([])
        self.max_val = torch.tensor([])

    def calculate_qparams(self, signed):
        return self._calculate_qparams(signed, self.min_val, self.max_val)

    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
        r"""Calculates nonuniform quantization parameters according to the APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.

        min_val and max_val are optional args that override the min_val and
        max_val observed by forward.

        Args:
            signed: specifies whether to include signed values in quantization level calculations
            min_val: optional arg that can override the min_val internal attribute
            max_val: optional arg that can override the max_val internal attribute

        Returns:
            alpha: alpha quantization parameter, max of abs value of observed values
            gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range
            quantization_levels: non-uniform quantization levels (fp representation)
            level_indices: int representation of quantization_levels indices
        """
        if min_val is not None:
            self.min_val = min_val
        if max_val is not None:
            self.max_val = max_val

        # compute alpha
        alpha = torch.max(-self.min_val, self.max_val)

        # check for valid inputs of b, k
        assert self.k and self.k != 0
        assert self.b % self.k == 0

        # compute n and store as member variable
        self.n = self.b // self.k

        # store a tensor of subtensors (all levels)
        p_all = []

        # create levels
        for i in range(0, self.n):
            p_curr = torch.tensor([0])

            for j in range(0, (2 ** self.k - 2) + 1):
                curr_ele = 2 ** (- (i + j * self.n))
                p_append = torch.tensor([curr_ele])
                p_curr = torch.cat((p_curr, p_append))
                # introduce signed numbers
                if signed:
                    p_curr = torch.cat((p_curr, torch.tensor([-curr_ele])))

            if signed:
                # sort tensor in reverse order before adding to list if signed
                sorted_levels, _ = torch.sort(p_curr, descending=True)
                p_all.append(sorted_levels)
            else:
                p_all.append(p_curr)

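        # Worked example of the level construction above (a sketch assuming
        # b=4, k=2, signed=False; these values are illustrative):
        #   n = 4 // 2 = 2 base tensors, each with (2 ** 2 - 2) + 1 = 3 powers:
        #     i=0 -> p_curr = [0, 2**0, 2**-2, 2**-4] = [0, 1, 0.25, 0.0625]
        #     i=1 -> p_curr = [0, 2**-1, 2**-3, 2**-5] = [0, 0.5, 0.125, 0.03125]
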
        # gamma calculation:
        # loop through all tensors
        # if signed, add element at index 0 for each tensor
        # else, add element at index 1 for each tensor
        # gamma defined to ensure alpha is at max of range
        p_sum = 0.0
        for tens in p_all:
            if signed:
                p_sum += float(tens[0])
            else:
                p_sum += float(tens[1])

        # assign gamma
        gamma = alpha / p_sum

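        # Continuing the b=4, k=2, signed=False sketch above:
        #   p_sum = 1 + 0.5 = 1.5 (the largest power from each base tensor),
        #   so gamma = alpha / 1.5 and the largest level, gamma * (1 + 0.5),
        #   lands exactly on alpha.
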
        # calculate cartesian product
        cartesian_product = list(itertools.product(*p_all))

        quantization_levels_list = []

        # calculate sum of each row
        for row in cartesian_product:
            row_sum = 0.0
            for ele in row:
                row_sum += ele
            quantization_levels_list.append(row_sum)

        quantization_levels_gamma = [float(gamma) * ele for ele in quantization_levels_list]
        quantization_levels = torch.tensor(quantization_levels_gamma)
        quantization_levels, level_indices = quantization_levels.sort()

        return (alpha, gamma, quantization_levels, level_indices)

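    # Finishing the b=4, k=2, signed=False sketch: the cartesian product has
    # 4 * 4 = 16 = 2 ** b rows, so the observer produces 2 ** b sorted
    # quantization levels, the largest of which equals alpha.
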
r"""Records the running minimum and maximum of ``x``.
|
|
Args:
|
|
x_orig: Tensor to be observed for min and max val"""
|
|
def forward(self, x_orig):
|
|
if x_orig.numel() == 0:
|
|
return x_orig
|
|
x = x_orig.detach()
|
|
min_val, max_val = torch.aminmax(x)
|
|
if self.min_val.numel():
|
|
min_val = torch.min(min_val, self.min_val)
|
|
if self.max_val.numel():
|
|
max_val = torch.max(max_val, self.max_val)
|
|
self.min_val = min_val
|
|
self.max_val = max_val
|
|
return x_orig
|
|
|
|
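    # Running-bounds sketch for forward (values are illustrative):
    #   obs = APoTObserver(b=4, k=2)
    #   obs(torch.tensor([0.2, 0.8]))  # min_val=0.2, max_val=0.8
    #   obs(torch.tensor([0.1, 0.5]))  # min_val drops to 0.1, max_val stays 0.8
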
r"""Displays visualization of APoT quantization levels
|
|
Args:
|
|
observer: APoTObserver to calculate qparams
|
|
signed: bool to indicate if qparams should be signed/unsigned
|
|
"""
|
|
def quant_levels_visualization(self, signed=False):
|
|
alpha, gamma, quantization_levels, level_indices = self.calculate_qparams(signed)
|
|
|
|
xs = [float(x) / 1000.0 for x in range(1000)]
|
|
ys = [apot_to_float(float_to_apot(x, quantization_levels, level_indices, alpha),
|
|
quantization_levels, level_indices).item() for x in xs]
|
|
|
|
f = plt.figure(figsize=(15, 10))
|
|
|
|
plt.plot(xs, ys)
|
|
plt.title("APoT Quantization Plot")
|
|
plt.xlabel("Full Precision")
|
|
plt.ylabel("Quantized")
|
|
plt.show()
|
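
# A minimal end-to-end demo sketch (not part of the original module; the b/k
# values and random calibration data below are illustrative assumptions):
if __name__ == "__main__":
    obs = APoTObserver(b=4, k=2)
    obs(torch.rand(1000))  # calibrate on random data
    alpha, gamma, levels, indices = obs.calculate_qparams(signed=False)
    print(f"alpha={alpha.item():.4f}, gamma={gamma.item():.4f}, "
          f"num levels={levels.numel()}")  # expect 2 ** 4 = 16 levels
    obs.quant_levels_visualization(signed=False)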