### Summary
This PR implements PTQ for APoT FakeQuant. It runs models (a pre-trained ResNet-18 on the ImageNet dataset) to compare accuracy metrics across qconfig settings that mix uniform and APoT quantized activations and weights. According to the collected accuracy stats, model #2 (uniform activation, APoT weight) shows a slight accuracy improvement over model #1 (uniform activation, uniform weight) at 8-bit and a significant improvement at 4-bit (see "Accuracy Stats" below).

### Test Plan
Run the models with:
`python test/quantization/core/experimental/fx_graph_mode_apot.py`

### Accuracy Stats

**8-bit** (Uniform int8, APoT b = 8, k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.43% (Top-1), 85.62% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.51% (Top-1), 85.78% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.32% (Top-1), 85.78% (Top-5)

**4-bit** (Uniform int4, APoT b = 4, k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.63% (Top-1), 71.96% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.24% (Top-1), 85.56% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.40% (Top-1), 76.21% (Top-5)

**Full Precision model (FX Graph Mode quantized)**
Evaluation accuracy on test dataset: 69.76% (Top-1), 89.08% (Top-5)

**Eager mode quantized model**
Evaluation accuracy on test dataset: 69.49% (Top-1), 88.90% (Top-5)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81040
Approved by: https://github.com/jerryzh168
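For context on how the numbers above are measured: the Top-1/Top-5 figures come from a standard ImageNet validation loop. The sketch below is illustrative only; `model` and `data_loader` (a quantized ResNet-18 and an ImageNet validation loader) are assumed inputs, not objects defined in this PR.

```python
import torch

def evaluate_top1_top5(model, data_loader):
    """Return (top1, top5) accuracy in percent over data_loader."""
    model.eval()
    top1 = top5 = total = 0
    with torch.no_grad():
        for images, targets in data_loader:
            outputs = model(images)
            # indices of the 5 highest-scoring classes for each sample
            _, pred = outputs.topk(5, dim=1)
            correct = pred.eq(targets.unsqueeze(1))
            top1 += correct[:, 0].sum().item()
            top5 += correct.any(dim=1).sum().item()
            total += targets.size(0)
    return 100.0 * top1 / total, 100.0 * top5 / total
```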
import torch
from torch import Tensor
import numpy as np
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float, quant_dequant_util

# class to store APoT quantizer and
# implement quantize and dequantize
class APoTQuantizer():
    alpha: torch.Tensor
    gamma: torch.Tensor
    quantization_levels: torch.Tensor
    level_indices: torch.Tensor

    def __init__(
            self,
            alpha: torch.Tensor,
            gamma: torch.Tensor,
            quantization_levels: torch.Tensor,
            level_indices: torch.Tensor) -> None:
        self.alpha = alpha
        self.gamma = gamma
        self.quantization_levels = quantization_levels
        self.level_indices = level_indices
r""" Quantizes fp Tensor to integer APoT representation.
|
|
Conversion is based on the qparams from a specified APoT non-uniform observer.
|
|
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
|
|
Args:
|
|
tensor2quantize: fp Tensor
|
|
Returns:
|
|
result: APoT Tensor representation of tensor2quantize
|
|
"""
|
|
def quantize(self, tensor2quantize: Tensor):
|
|
result = torch.tensor([])
|
|
|
|
# map float_to_apot over tensor2quantize elements
|
|
tensor2quantize = tensor2quantize.detach().apply_(lambda x: float_to_apot(x,
|
|
self.quantization_levels,
|
|
self.level_indices,
|
|
self.alpha))
|
|
|
|
# convert to APoT int representation for dtype
|
|
tensor2quantize = tensor2quantize.int()
|
|
|
|
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
|
|
|
|
result = TensorAPoT(self, tensor2quantize)
|
|
|
|
return result
|
|
|
|
r""" Dequantizes integer Tensor to floating point (fp) representation
|
|
based on the calculated quantization levels from a specified APoT non-uniform observer.
|
|
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
|
|
Args:
|
|
tensor2quantize: fp Tensor
|
|
Returns:
|
|
result: fp reduced precision representation of input Tensor
|
|
"""
|
|
def dequantize(self, apot_tensor) -> Tensor:
|
|
orig_size = apot_tensor.data.size()
|
|
apot_tensor_data = apot_tensor.data.flatten()
|
|
|
|
print(apot_tensor_data)
|
|
|
|
# map apot_to_float over tensor2quantize elements
|
|
result_temp = np.empty(shape=apot_tensor_data.size())
|
|
for i in range(len(apot_tensor_data)):
|
|
new_ele = apot_to_float(apot_tensor_data[i], self.quantization_levels, self.level_indices)
|
|
result_temp[i] = new_ele
|
|
|
|
result = torch.from_numpy(result_temp).reshape(orig_size)
|
|
|
|
return result
|
|
|
|
r""" Returns result of quantize -> dequantize on a fp Tensor (reduced precision)
|
|
based on the calculated quantization levels from a specified APoT non-uniform observer.
|
|
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
|
|
Args:
|
|
apot_tensor: quantized APoT Tensor to dequantize
|
|
Returns:
|
|
result: fp representation of input Tensor
|
|
"""
|
|
def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
|
|
levels_lst = list(self.quantization_levels)
|
|
|
|
result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
|
|
|
|
return result
|
|
|
|
def q_apot_alpha(self) -> float:
|
|
raise NotImplementedError
|
|
|
|
r""" Global method to create quantizer and call quantizer quantize_APoT
|
|
Args:
|
|
tensor2quantize: fp Tensor to quantize
|
|
alpha: Tensor qparam alpha (clipping level)
|
|
gamma: Tensor qparam gamma (scale factor for quantization levels)
|
|
quantization levels: Tensor with fp quantization levels
|
|
level indices: Tensor with integer quantization level indices
|
|
Returns:
|
|
result: ApoT Tensor representation of tensor2quantize
|
|
"""
|
|
def quantize_APoT(tensor2quantize: Tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor):
|
|
quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
|
|
result = quantizer.quantize(tensor2quantize)
|
|
return result
|
|
|
|
r""" Global method to create quantizer and call quantizer dequantize_APoT
|
|
Args:
|
|
apot_tensor: APoT Tensor to dequantize
|
|
Returns:
|
|
result: fp Tensor dequantized from apot_tensor
|
|
"""
|
|
def dequantize_APoT(apot_tensor) -> Tensor:
|
|
quantizer = apot_tensor.quantizer
|
|
result = quantizer.dequantize(apot_tensor)
|
|
return result
|
|
|
|
r""" Global method to create quantizer and call quantizer quant_dequant
|
|
Args:
|
|
tensor2quantize: fp Tensor to quantize
|
|
alpha: Tensor qparam alpha (clipping level)
|
|
gamma: Tensor qparam gamma (scale factor for quantization levels)
|
|
quantization levels: Tensor with fp quantization levels
|
|
level indices: Tensor with integer quantization level indices
|
|
Returns:
|
|
result: fp reduced precision Tensor from tensor2quantize
|
|
"""
|
|
def quant_dequant_APoT(tensor2quantize: Tensor,
|
|
alpha: Tensor,
|
|
gamma: Tensor,
|
|
quantization_levels: Tensor,
|
|
level_indices: Tensor) -> Tensor:
|
|
quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
|
|
result = quantizer.quant_dequant(tensor2quantize)
|
|
return result
|
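

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an editorial addition, not part of the module's
# API). The levels, indices, alpha, and gamma below are hypothetical
# hand-picked values; in practice they come from an APoT non-uniform
# observer's calculated qparams.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # four hypothetical non-uniform levels in [0, 1] and their integer indices
    example_levels = torch.tensor([0.0, 0.25, 0.5, 1.0])
    example_indices = torch.tensor([0, 1, 2, 3])
    example_alpha = torch.tensor(1.0)   # clipping level
    example_gamma = torch.tensor(1.0)   # scale factor for the levels

    x = torch.rand(2, 3)

    # quantize to an APoT tensor, then dequantize back to fp
    x_apot = quantize_APoT(x.clone(), example_alpha, example_gamma,
                           example_levels, example_indices)
    x_deq = dequantize_APoT(x_apot)

    # single-call fake-quantization (quantize -> dequantize)
    x_qdq = quant_dequant_APoT(x.clone(), example_alpha, example_gamma,
                               example_levels, example_indices)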