# mypy: allow-untyped-defs """ This module implements nonuniform observers used to collect statistics about the values observed during calibration (PTQ) or training (QAT). """ import itertools import torch from torch.ao.quantization.experimental.apot_utils import apot_to_float, float_to_apot from torch.ao.quantization.observer import ObserverBase # TODO: Consider adding NonUniformQuantizationObserverBase class # when more than one non-uniform method is implemented class APoTObserver(ObserverBase): b: int k: int n: int min_val: torch.Tensor max_val: torch.Tensor def __init__(self, b, k, dtype=torch.quint8) -> None: super().__init__(dtype) self.b = b self.k = k self.min_val = torch.tensor([]) self.max_val = torch.tensor([]) # min_val and max_val are optional args to override # the min_val and max_val observed by forward def calculate_qparams(self, signed): # type:ignore[override] return self._calculate_qparams(signed, self.min_val, self.max_val) r""" Calculates nonuniform quantization parameters according to APoT paper: https://arxiv.org/pdf/1909.13144.pdf. Arg: signed: specifies whether to include signed values in quantization level calculations min_val: optional arg that can override min_val internal attribute max_val: optional arg that can override max_val internal attribute Returns: alpha: alpha quantization parameter, max of abs value of observed values gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range quantization_levels: non-uniform quantization levels (fp representation) level_indices: int representation of quantization_levels indices """ def _calculate_qparams(self, signed: bool, min_val=None, max_val=None): if min_val is not None: self.min_val = min_val if max_val is not None: self.max_val = max_val # compute alpha alpha = torch.max(-self.min_val, self.max_val) # check for valid inputs of b, k if not self.k or self.k == 0: raise AssertionError(f"k must be a non-zero integer, got k={self.k}") if self.b % self.k != 0: raise AssertionError( f"b must be divisible by k, got b={self.b}, k={self.k}" ) # compute n and store as member variable self.n = self.b // self.k # store a tensor of subtensors (all levels) p_all = [] # create levels for i in range(self.n): p_curr = torch.tensor([0]) for j in range((2**self.k - 2) + 1): curr_ele = 2 ** (-(i + j * self.n)) p_append = torch.tensor([curr_ele]) p_curr = torch.cat((p_curr, p_append)) # introduce signed numbers if signed: p_curr = torch.cat((p_curr, torch.tensor([-curr_ele]))) if signed: # sort tensor in reverse order before adding to list if signed sorted, _indices = torch.sort(p_curr, descending=True) p_all.append(sorted) else: p_all.append(p_curr) # gamma calculation: # loop through all tensors # if signed, add element at index 0 for each tensor # else, add element at index 1 for each tensor # gamma defined to ensure alpha is at max of range p_sum = 0.0 for tens in p_all: if signed: p_sum += float(tens[0]) else: p_sum += float(tens[1]) # assign gamma gamma = alpha / p_sum # calculate cartesian product cartesian_product = list(itertools.product(*p_all)) quantization_levels_list = [] # calculate sum of each row for row in cartesian_product: sum = 0.0 for ele in row: sum += ele quantization_levels_list.append(sum) quantization_levels_gamma = [ float(gamma) * ele for ele in quantization_levels_list ] quantization_levels = torch.tensor(quantization_levels_gamma) level_indices = torch.tensor([]) quantization_levels, level_indices = quantization_levels.sort() return (alpha, gamma, quantization_levels, level_indices) r"""Records the running minimum and maximum of ``x``. Args: x_orig: Tensor to be observed for min and max val""" def forward(self, x_orig): if x_orig.numel() == 0: return x_orig x = x_orig.detach() min_val, max_val = torch.aminmax(x) if self.min_val.numel(): min_val = torch.min(min_val, self.min_val) if self.max_val.numel(): max_val = torch.max(max_val, self.max_val) self.min_val = min_val self.max_val = max_val return x_orig r"""Displays visualization of APoT quantization levels Args: observer: APoTObserver to calculate qparams signed: bool to indicate if qparams should be signed/unsigned """ def quant_levels_visualization(self, signed=False): # matplotlib is optional dep import matplotlib.pyplot as plt alpha, _gamma, quantization_levels, level_indices = self.calculate_qparams( signed ) xs = [float(x) / 1000.0 for x in range(1000)] ys = [ apot_to_float( float_to_apot(x, quantization_levels, level_indices, alpha), quantization_levels, level_indices, ).item() for x in xs ] plt.figure(figsize=(15, 10)) plt.plot(xs, ys) plt.title("APoT Quantization Plot") plt.xlabel("Full Precision") plt.ylabel("Quantized") plt.show()