pytorch/torch/ao/quantization/experimental/quantizer.py
asl3 777c12f2df [quant] Modify APoT nonuniform quantization workflow (#80075)
### Summary:
This PR updates the design of the APoT Observer, Quantizer, and Tensor to be more consistent with their uniform counterparts in the PyTorch framework. APoT Observer now calculates alpha as the maximum of the absolute values of the max and min values in the input tensor. APoT Quantizer is modified so that its instance methods `quantize` and `dequantize` are called by the global functions `quantize_APoT` and `dequantize_APoT`. APoT Tensor is modified to account for the new definition of `quantize_APoT` in APoT Quantizer.

### Test Plan:
Run APoT Observer class unit tests with: `python pytorch/test/quantization/core/experimental/test_nonuniform_observer.py`
Run APoT Quantizer class unit tests with: `python pytorch/test/quantization/core/experimental/test_quantizer.py`
Run APoT Tensor class unit tests with: `python pytorch/test/quantization/core/experimental/test_quantized_tensor.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80075
Approved by: https://github.com/jerryzh168
2022-06-27 14:54:06 +00:00

95 lines
3.9 KiB
Python

import torch
from torch import Tensor
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
# Class that stores the APoT quantizer qparams and implements
# quantize (fp -> APoT level indices) and dequantize (APoT -> fp).
class APoTQuantizer():
    alpha: torch.Tensor
    gamma: torch.Tensor
    quantization_levels: torch.Tensor
    level_indices: torch.Tensor

    def __init__(
        self,
        alpha: torch.Tensor,
        gamma: torch.Tensor,
        quantization_levels: torch.Tensor,
        level_indices: torch.Tensor) -> None:
        self.alpha = alpha
        self.gamma = gamma
        self.quantization_levels = quantization_levels
        self.level_indices = level_indices

    def quantize(self, tensor2quantize: Tensor):
        r""" Quantizes fp Tensor to integer APoT representation.
        Conversion is based on the qparams from a specified APoT non-uniform observer.
        The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
        Args:
            tensor2quantize: fp Tensor
        Returns:
            result: APoT Tensor representation of tensor2quantize
        """
        # clip tensor2quantize values to [-alpha, alpha] based on the alpha qparam;
        # torch.clamp returns a new tensor, so the in-place apply_ below never
        # touches the caller's input
        clipped = torch.clamp(tensor2quantize, -self.alpha, self.alpha)
        # map float_to_apot over the clipped elements (in place on the clamped copy);
        # note Tensor.apply_ only works on CPU tensors
        clipped.apply_(lambda x: float_to_apot(x, self.quantization_levels, self.level_indices))
        # function-level import; presumably avoids a circular import with APoT_tensor (TODO confirm)
        from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
        return TensorAPoT(self, clipped)

    def dequantize(self, apot_tensor) -> Tensor:
        r""" Dequantizes integer Tensor to floating point (fp) representation
        based on the calculated quantization levels from a specified APoT non-uniform observer.
        The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
        Args:
            apot_tensor: quantized APoT Tensor to dequantize
        Returns:
            result: fp representation of input Tensor
        """
        # clone first: apply_ works in place, and running it directly on
        # apot_tensor.data would overwrite the quantized values stored in apot_tensor
        result = apot_tensor.data.clone()
        # map apot_to_float over the APoT tensor elements
        result.apply_(lambda x: float(apot_to_float(x, self.quantization_levels, self.level_indices)))
        return result

    def q_apot_alpha(self) -> float:
        raise NotImplementedError
def quantize_APoT(tensor2quantize: Tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor):
    r""" Global method to create an APoTQuantizer and call its ``quantize`` method.
    Args:
        tensor2quantize: fp Tensor to quantize
        alpha: Tensor qparam alpha (clipping level)
        gamma: Tensor qparam gamma (scale factor for quantization levels)
        quantization_levels: Tensor with fp quantization levels
        level_indices: Tensor with integer quantization level indices
    Returns:
        result: APoT Tensor representation of tensor2quantize
    """
    quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
    result = quantizer.quantize(tensor2quantize)
    return result
def dequantize_APoT(apot_tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor) -> Tensor:
    r""" Global method to create an APoTQuantizer and call its ``dequantize`` method.
    Args:
        apot_tensor: APoT Tensor to dequantize
        alpha: Tensor qparam alpha (clipping level)
        gamma: Tensor qparam gamma (scale factor for quantization levels)
        quantization_levels: Tensor with fp quantization levels
        level_indices: Tensor with integer quantization level indices
    Returns:
        result: fp Tensor dequantized from apot_tensor
    """
    quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
    # dequantize with the quantizer built from the qparams passed in, so the
    # alpha / gamma / quantization_levels / level_indices arguments are honored
    result = quantizer.dequantize(apot_tensor)
    return result