mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
### Summary:
This PR updates the design of the APoT Observer, Quantizer, and Tensor to be more consistent with their uniform counterparts in the PyTorch framework. The APoT Observer now calculates alpha as the maximum of the absolute values of the max and min values in the input tensor. The APoT Quantizer is modified so that its instance methods `quantize` and `dequantize` are called by their global method counterparts `quantize_APoT` and `dequantize_APoT`. The APoT Tensor is modified to account for the new definition of `quantize_APoT` in the APoT Quantizer.

### Test Plan:
Run APoT Observer class unit tests with: `python pytorch/test/quantization/core/experimental/test_nonuniform_observer.py`
Run APoT Quantizer class unit tests with: `python pytorch/test/quantization/core/experimental/test_quantizer.py`
Run APoT Tensor class unit tests with: `python pytorch/test/quantization/core/experimental/test_quantized_tensor.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80075
Approved by: https://github.com/jerryzh168
95 lines
3.9 KiB
Python
95 lines
3.9 KiB
Python
import torch
|
|
from torch import Tensor
|
|
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
|
|
|
|
# Class that stores the APoT quantizer's qparams and
# implements quantize and dequantize
|
|
class APoTQuantizer():
|
|
alpha: torch.Tensor
|
|
gamma: torch.Tensor
|
|
quantization_levels: torch.Tensor
|
|
level_indices: torch.Tensor
|
|
|
|
def __init__(
|
|
self,
|
|
alpha: torch.Tensor,
|
|
gamma: torch.Tensor,
|
|
quantization_levels: torch.Tensor,
|
|
level_indices: torch.Tensor) -> None:
|
|
self.alpha = alpha
|
|
self.gamma = gamma
|
|
self.quantization_levels = quantization_levels
|
|
self.level_indices = level_indices
|
|
|
|
r""" Quantizes fp Tensor to integer APoT representation.
|
|
Conversion is based on the qparams from a specified APoT non-uniform observer.
|
|
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
|
|
Args:
|
|
tensor2quantize: fp Tensor
|
|
Returns:
|
|
result: APoT Tensor representation of tensor2quantize
|
|
"""
|
|
def quantize(self, tensor2quantize: Tensor):
|
|
result = torch.tensor([])
|
|
|
|
# clip tensor2quantize values based on alpha qparam
|
|
tensor2quantize = torch.clamp(tensor2quantize, -self.alpha, self.alpha)
|
|
|
|
# map float_to_apot over tensor2quantize elements
|
|
tensor2quantize = tensor2quantize.apply_(lambda x: float_to_apot(x, self.quantization_levels, self.level_indices))
|
|
|
|
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
|
|
|
|
result = TensorAPoT(self, tensor2quantize)
|
|
|
|
return result
|
|
|
|
r""" Dequantizes integer Tensor to floating point (fp) representation
|
|
based on the calculated quantization levels from a specified APoT non-uniform observer.
|
|
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
|
|
Args:
|
|
apot_tensor: quantized APoT Tensor to dequantize
|
|
Returns:
|
|
result: fp representation of input Tensor
|
|
"""
|
|
def dequantize(self, apot_tensor) -> Tensor:
|
|
apot_tensor_data = apot_tensor.data
|
|
|
|
# map apot_to_float over tensor2quantize elements
|
|
result = apot_tensor_data.apply_(lambda x: float(apot_to_float(x, self.quantization_levels, self.level_indices)))
|
|
|
|
return result
|
|
|
|
def q_apot_alpha(self) -> float:
|
|
raise NotImplementedError
|
|
|
|
def quantize_APoT(tensor2quantize: Tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor):
    r"""Global method that builds an APoTQuantizer from the given qparams
    and quantizes ``tensor2quantize`` with it.

    Args:
        tensor2quantize: fp Tensor to quantize
        alpha: Tensor qparam alpha (clipping level)
        gamma: Tensor qparam gamma (scale factor for quantization levels)
        quantization_levels: Tensor with fp quantization levels
        level_indices: Tensor with integer quantization level indices
    Returns:
        result: APoT Tensor representation of tensor2quantize
    """
    apot_quantizer = APoTQuantizer(
        alpha=alpha,
        gamma=gamma,
        quantization_levels=quantization_levels,
        level_indices=level_indices,
    )
    return apot_quantizer.quantize(tensor2quantize)
|
|
|
|
def dequantize_APoT(apot_tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor) -> Tensor:
    r"""Global method that dequantizes an APoT Tensor with the quantizer
    attached to it (``apot_tensor.quantizer``).

    The qparam arguments are accepted for interface compatibility only and do
    not influence the result: dequantization always uses the quantizer stored
    on ``apot_tensor``.

    Args:
        apot_tensor: APoT Tensor to dequantize
        alpha: Tensor qparam alpha (clipping level); unused
        gamma: Tensor qparam gamma (scale factor for quantization levels); unused
        quantization_levels: Tensor with fp quantization levels; unused
        level_indices: Tensor with integer quantization level indices; unused
    Returns:
        result: fp Tensor dequantized from apot_tensor
    """
    # A quantizer used to be built from the qparam arguments here, but it was
    # never used; the tensor's own quantizer performs the dequantization.
    return apot_tensor.quantizer.dequantize(apot_tensor)
|