# mypy: allow-untyped-defs
"""
This module implements nonuniform observers used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
"""

import itertools

import torch
from torch.ao.quantization.experimental.apot_utils import apot_to_float, float_to_apot
from torch.ao.quantization.observer import ObserverBase


# TODO: Consider adding NonUniformQuantizationObserverBase class
# when more than one non-uniform method is implemented
class APoTObserver(ObserverBase):
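    # b: total bitwidth; k: base bitwidth of each additive term;
    # n = b // k is the number of power-of-two terms (per the APoT paper)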
b: int
k: int
n: int
min_val: torch.Tensor
max_val: torch.Tensor

    def __init__(self, b, k, dtype=torch.quint8) -> None:
super().__init__(dtype)
self.b = b
self.k = k
self.min_val = torch.tensor([])
self.max_val = torch.tensor([])

    # min_val and max_val are optional args to override
    # the min_val and max_val observed by forward
    def calculate_qparams(self, signed):  # type:ignore[override]
        return self._calculate_qparams(signed, self.min_val, self.max_val)
r""" Calculates nonuniform quantization parameters according to APoT paper:
https://arxiv.org/pdf/1909.13144.pdf.
Arg:
signed: specifies whether to include signed values in quantization level calculations
min_val: optional arg that can override min_val internal attribute
max_val: optional arg that can override max_val internal attribute
Returns:
alpha: alpha quantization parameter, max of abs value of observed values
gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range
quantization_levels: non-uniform quantization levels (fp representation)
level_indices: int representation of quantization_levels indices
"""
def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
if min_val is not None:
self.min_val = min_val
if max_val is not None:
self.max_val = max_val
# compute alpha
alpha = torch.max(-self.min_val, self.max_val)
# check for valid inputs of b, k
        if not self.k:
            raise AssertionError(f"k must be a non-zero integer, got k={self.k}")
if self.b % self.k != 0:
raise AssertionError(
f"b must be divisible by k, got b={self.b}, k={self.k}"
)
# compute n and store as member variable
self.n = self.b // self.k
        # build a list of n tensors, one per additive power-of-two term;
        # term i holds a zero level plus the levels 2^(-(i + j*n))
        # for j = 0 .. 2^k - 2
        p_all = []
        for i in range(self.n):
            p_curr = torch.tensor([0])
            for j in range((2**self.k - 2) + 1):
curr_ele = 2 ** (-(i + j * self.n))
p_append = torch.tensor([curr_ele])
p_curr = torch.cat((p_curr, p_append))
# introduce signed numbers
if signed:
p_curr = torch.cat((p_curr, torch.tensor([-curr_ele])))
            if signed:
                # sort tensor in reverse order before adding to list if signed
                sorted_levels, _ = torch.sort(p_curr, descending=True)
                p_all.append(sorted_levels)
            else:
                p_all.append(p_curr)
        # gamma calculation:
        # gamma scales the levels so that the largest representable value
        # equals alpha; accumulate the largest level of each term tensor
        # (index 0 when signed, since the tensor is sorted descending;
        # index 1 otherwise, since index 0 holds the zero level)
p_sum = 0.0
for tens in p_all:
if signed:
p_sum += float(tens[0])
else:
p_sum += float(tens[1])
# assign gamma
gamma = alpha / p_sum
        # each APoT quantization level is the sum of one level drawn from
        # each of the n term tensors (cartesian product across terms)
        cartesian_product = list(itertools.product(*p_all))
        quantization_levels_list = []
        # calculate the sum of each row
        for row in cartesian_product:
            row_sum = 0.0
            for ele in row:
                row_sum += ele
            quantization_levels_list.append(row_sum)
quantization_levels_gamma = [
float(gamma) * ele for ele in quantization_levels_list
]
quantization_levels = torch.tensor(quantization_levels_gamma)
        quantization_levels, level_indices = quantization_levels.sort()
return (alpha, gamma, quantization_levels, level_indices)
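
    # A worked micro-example of _calculate_qparams (illustrative values, not
    # exercised anywhere in this module): with b=2, k=1, signed=False, n = 2
    # and the term tensors are [0, 2^0] and [0, 2^-1], so
    # p_sum = 1.0 + 0.5 = 1.5 and gamma = alpha / 1.5; the cartesian-product
    # sums {0, 0.5, 1.0, 1.5} scale to levels {0, alpha/3, 2*alpha/3, alpha}.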
r"""Records the running minimum and maximum of ``x``.
Args:
x_orig: Tensor to be observed for min and max val"""
def forward(self, x_orig):
if x_orig.numel() == 0:
return x_orig
x = x_orig.detach()
min_val, max_val = torch.aminmax(x)
if self.min_val.numel():
min_val = torch.min(min_val, self.min_val)
if self.max_val.numel():
max_val = torch.max(max_val, self.max_val)
self.min_val = min_val
self.max_val = max_val
return x_orig
r"""Displays visualization of APoT quantization levels
Args:
observer: APoTObserver to calculate qparams
signed: bool to indicate if qparams should be signed/unsigned
"""
def quant_levels_visualization(self, signed=False):
# matplotlib is optional dep
import matplotlib.pyplot as plt
alpha, _gamma, quantization_levels, level_indices = self.calculate_qparams(
signed
)
xs = [float(x) / 1000.0 for x in range(1000)]
ys = [
apot_to_float(
float_to_apot(x, quantization_levels, level_indices, alpha),
quantization_levels,
level_indices,
).item()
for x in xs
]
plt.figure(figsize=(15, 10))
plt.plot(xs, ys)
plt.title("APoT Quantization Plot")
plt.xlabel("Full Precision")
plt.ylabel("Quantized")
plt.show()
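

# A minimal usage sketch (illustrative; b=8, k=2 and the random input are
# assumptions, not values prescribed by this module): record min/max over a
# batch, then derive the APoT quantization parameters.
if __name__ == "__main__":
    observer = APoTObserver(b=8, k=2)
    x = torch.randn(1000)
    observer(x)  # forward() records the running min/max of x
    alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(
        signed=True
    )
    print(f"alpha={alpha:.4f}, gamma={gamma:.4f}")
    print(f"{quantization_levels.numel()} quantization levels")
    # observer.quant_levels_visualization(signed=True)  # requires matplotlib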