# mypy: allow-untyped-defs
"""
This module implements nonuniform observers used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
"""

import itertools

import torch
from torch.ao.quantization.experimental.apot_utils import apot_to_float, float_to_apot
from torch.ao.quantization.observer import ObserverBase


# TODO: Consider adding NonUniformQuantizationObserverBase class
# when more than one non-uniform method is implemented


class APoTObserver(ObserverBase):
    b: int
    k: int
    n: int
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(self, b, k, dtype=torch.quint8) -> None:
        super().__init__(dtype)
        self.b = b
        self.k = k

        self.min_val = torch.tensor([])
        self.max_val = torch.tensor([])

    # min_val and max_val are optional args to override
    # the min_val and max_val observed by forward
    def calculate_qparams(self, signed):  # type:ignore[override]
        return self._calculate_qparams(signed, self.min_val, self.max_val)

    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
        r"""Calculates nonuniform quantization parameters according to the APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.

        Args:
            signed: specifies whether to include signed values in quantization level calculations
            min_val: optional arg that overrides the min_val attribute observed by forward
            max_val: optional arg that overrides the max_val attribute observed by forward

        Returns:
            alpha: quantization parameter, the max of the absolute values observed
            gamma: quantization parameter, defined so that alpha is the maximum of the range
            quantization_levels: nonuniform quantization levels (fp representation)
            level_indices: integer representation of quantization_levels indices
        """
        if min_val is not None:
            self.min_val = min_val
        if max_val is not None:
            self.max_val = max_val

        # compute alpha, the max of the absolute observed range
        alpha = torch.max(-self.min_val, self.max_val)

        # check for valid inputs of b, k
        if not self.k:
            raise AssertionError(f"k must be a non-zero integer, got k={self.k}")
        if self.b % self.k != 0:
            raise AssertionError(
                f"b must be divisible by k, got b={self.b}, k={self.k}"
            )

        # compute n, the number of additive terms, and store as member variable
        self.n = self.b // self.k

        # store a list of tensors (one set of levels per additive term)
        p_all = []

        # create levels
        for i in range(self.n):
            p_curr = torch.tensor([0])

            # 2**k - 1 nonzero entries per level set
            for j in range(2**self.k - 1):
                curr_ele = 2 ** (-(i + j * self.n))
                p_append = torch.tensor([curr_ele])
                p_curr = torch.cat((p_curr, p_append))
                # introduce signed numbers
                if signed:
                    p_curr = torch.cat((p_curr, torch.tensor([-curr_ele])))

            if signed:
                # sort the tensor in descending order before adding it to the list
                sorted_levels, _indices = torch.sort(p_curr, descending=True)
                p_all.append(sorted_levels)
            else:
                p_all.append(p_curr)

        # gamma calculation:
        # loop through all tensors;
        # if signed, add the element at index 0 of each tensor,
        # else add the element at index 1 of each tensor.
        # gamma is defined so that alpha is the maximum of the range.
        p_sum = 0.0
        for tens in p_all:
            if signed:
                p_sum += float(tens[0])
            else:
                p_sum += float(tens[1])

        # assign gamma
        gamma = alpha / p_sum

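        # Worked example (derived from the code above, not from the paper):
        # with b=4, k=2 (so n=2) and signed=False, the level sets are
        #   p_all[0] = [0, 2**0, 2**-2, 2**-4] and p_all[1] = [0, 2**-1, 2**-3, 2**-5],
        # so p_sum = 2**0 + 2**-1 = 1.5 and gamma = alpha / 1.5; the largest
        # quantization level, gamma * (2**0 + 2**-1), then equals alpha exactly.
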
        # calculate the Cartesian product of the level sets
        cartesian_product = list(itertools.product(*p_all))

        # each quantization level is the sum of one element from each level set
        quantization_levels_list = []
        for row in cartesian_product:
            row_sum = 0.0
            for ele in row:
                row_sum += ele
            quantization_levels_list.append(row_sum)

        # scale by gamma and sort the levels in ascending order
        quantization_levels_gamma = [
            float(gamma) * ele for ele in quantization_levels_list
        ]
        quantization_levels = torch.tensor(quantization_levels_gamma)
        quantization_levels, level_indices = quantization_levels.sort()

        return (alpha, gamma, quantization_levels, level_indices)

r"""Records the running minimum and maximum of ``x``.
|
|
Args:
|
|
x_orig: Tensor to be observed for min and max val"""
|
|
|
|
def forward(self, x_orig):
|
|
if x_orig.numel() == 0:
|
|
return x_orig
|
|
x = x_orig.detach()
|
|
min_val, max_val = torch.aminmax(x)
|
|
if self.min_val.numel():
|
|
min_val = torch.min(min_val, self.min_val)
|
|
if self.max_val.numel():
|
|
max_val = torch.max(max_val, self.max_val)
|
|
self.min_val = min_val
|
|
self.max_val = max_val
|
|
return x_orig
|
|
|
|
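    # Running-range behavior, for example: observing torch.tensor([1.0, -2.0])
    # and then torch.tensor([3.0]) leaves min_val == -2.0 and max_val == 3.0,
    # since each call takes the min/max against the previously recorded range.
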
r"""Displays visualization of APoT quantization levels
|
|
Args:
|
|
observer: APoTObserver to calculate qparams
|
|
signed: bool to indicate if qparams should be signed/unsigned
|
|
"""
|
|
|
|
def quant_levels_visualization(self, signed=False):
|
|
# matplotlib is optional dep
|
|
import matplotlib.pyplot as plt
|
|
|
|
alpha, _gamma, quantization_levels, level_indices = self.calculate_qparams(
|
|
signed
|
|
)
|
|
|
|
xs = [float(x) / 1000.0 for x in range(1000)]
|
|
ys = [
|
|
apot_to_float(
|
|
float_to_apot(x, quantization_levels, level_indices, alpha),
|
|
quantization_levels,
|
|
level_indices,
|
|
).item()
|
|
for x in xs
|
|
]
|
|
|
|
plt.figure(figsize=(15, 10))
|
|
|
|
plt.plot(xs, ys)
|
|
plt.title("APoT Quantization Plot")
|
|
plt.xlabel("Full Precision")
|
|
plt.ylabel("Quantized")
|
|
plt.show()
|
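

# A minimal usage sketch (illustrative only, not part of the library API):
# the observer is an nn.Module, so calling it runs forward and records the
# observed range, after which calculate_qparams produces the APoT parameters.
if __name__ == "__main__":
    observer = APoTObserver(b=4, k=2)
    observer(torch.randn(16))  # record min/max of a calibration batch
    alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(
        signed=False
    )
    print(f"alpha={alpha}, gamma={gamma}, levels={quantization_levels}")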