diff --git a/mypy.ini b/mypy.ini
index c1f813b7397..98995175736 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -64,6 +64,9 @@ ignore_missing_imports = True
 [mypy-torch.ao.quantization.experimental.quantizer]
 ignore_missing_imports = True
 
+[mypy-torch.ao.quantization.experimental.observer]
+ignore_missing_imports = True
+
 #
 # Files with various errors. Mostly real errors, possibly some false
 # positives as well.
diff --git a/test/quantization/core/experimental/test_quantizer.py b/test/quantization/core/experimental/test_quantizer.py
index ef4f60b36f3..c219c7fdad4 100644
--- a/test/quantization/core/experimental/test_quantizer.py
+++ b/test/quantization/core/experimental/test_quantizer.py
@@ -1,14 +1,76 @@
 # Owner(s): ["oncall: quantization"]
 
 import torch
+from torch import quantize_per_tensor
 from torch.ao.quantization.experimental.quantizer import APoTQuantizer
 import unittest
+import random
 
 class TestQuantizer(unittest.TestCase):
-    def test_quantize_APoT(self):
-        t = torch.Tensor()
-        with self.assertRaises(NotImplementedError):
-            APoTQuantizer.quantize_APoT(t)
+    r""" Tests quantize_APoT result on a random 1-dim tensor
+        and hardcoded values for b, k by comparing to the uniformly quantized tensor
+        (non-uniform quantization reduces to uniform quantization for k = 1)
+        (https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html)
+        * tensor2quantize: Tensor
+        * b: 4
+        * k: 1
+    """
+    def test_quantize_APoT_rand_k1(self):
+        # generate random size of tensor2quantize between 1 -> 20
+        size = random.randint(1, 20)
+
+        # generate tensor with random fp values between 0 -> 1000
+        tensor2quantize = 1000 * torch.rand(size, dtype=torch.float)
+
+        quantizer = APoTQuantizer(4, 1, torch.max(tensor2quantize), False)
+
+        # get apot quantized tensor result
+        qtensor = quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
+
+        # get uniform quantization quantized tensor result
+        uniform_quantized = quantize_per_tensor(input=tensor2quantize, scale=1.0, zero_point=0, dtype=torch.quint8).int_repr()
+
+        qtensor_data = torch.tensor(qtensor).type(torch.uint8)
+        uniform_quantized_tensor = uniform_quantized.data
+
+        self.assertTrue(torch.equal(qtensor_data, uniform_quantized_tensor))
+
+    r""" Tests quantize_APoT for k != 1.
+        Tests quantize_APoT result on a 1-dim tensor and hardcoded values for
+        b=4, k=2 by comparing results to hand-calculated results from the APoT paper
+        https://arxiv.org/pdf/1909.13144.pdf
+        * tensor2quantize: Tensor
+        * b: 4
+        * k: 2
+    """
+    def test_quantize_APoT_k2(self):
+        r"""
+        given b = 4, k = 2, alpha = 1.0, we know:
+        (from APoT paper example: https://arxiv.org/pdf/1909.13144.pdf)
+
+        quantization_levels = tensor([0.0000, 0.0208, 0.0417, 0.0625, 0.0833, 0.1250, 0.1667,
+        0.1875, 0.2500, 0.3333, 0.3750, 0.5000, 0.6667, 0.6875, 0.7500, 1.0000])
+
+        level_indices = tensor([ 0,  3, 12, 15,  2, 14,  8, 11, 10,  1, 13,  9,  4,  7,  6,  5])
+        """
+
+        # hardcoded fp values to quantize, drawn from the APoT paper example (range [0, alpha])
+        tensor2quantize = torch.tensor([0.0215, 0.1692, 0.385, 0.0391])
+
+        quantizer = APoTQuantizer(4, 2, 1.0, False)
+
+        # get apot quantized tensor result
+        qtensor = quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
+        qtensor_data = torch.tensor(qtensor).type(torch.uint8)
+
+        # expected qtensor values calculated by mapping each fp value in tensor2quantize
+        # to its nearest quantization level and taking the corresponding level index
+        # e.g. 0.0215 in tensor2quantize is nearest to 0.0208 in quantization_levels -> 3 in level_indices
+        expected_qtensor = torch.tensor([3, 8, 13, 12], dtype=torch.uint8)
+
+        self.assertTrue(torch.equal(qtensor_data, expected_qtensor))
 
     def test_dequantize(self):
         with self.assertRaises(NotImplementedError):
diff --git a/torch/ao/quantization/experimental/quantizer.py b/torch/ao/quantization/experimental/quantizer.py
index 7878a7fc7e0..bc6f7b2f17a 100644
--- a/torch/ao/quantization/experimental/quantizer.py
+++ b/torch/ao/quantization/experimental/quantizer.py
@@ -1,12 +1,55 @@
+import torch
 from torch import Tensor
+from torch.ao.quantization.experimental.observer import APoTObserver
+from torch.ao.quantization.experimental.apot_utils import float_to_apot
 
 # class to store APoT quantizer
 # implements quantize and dequantize
 # and stores all quantization parameters
 class APoTQuantizer():
-    @staticmethod
-    def quantize_APoT(tensor2quantize: Tensor) -> Tensor:
-        raise NotImplementedError
+    b: int
+    k: int
+    n: int
+    signed: bool
+    quantization_levels: torch.Tensor
+    level_indices: torch.Tensor
+
+    def __init__(
+            self,
+            b,
+            k,
+            max_val,
+            signed,
+            dtype=torch.quint8) -> None:
+        self.signed = signed
+
+        # check for valid inputs of b, k
+        assert(k and k != 0)
+        assert(b % k == 0)
+        self.b = b
+        self.k = k
+        self.n = b // k
+
+        # make observer, get quantization levels and level indices
+        obs = APoTObserver(max_val=max_val, b=b, k=k)
+        obs_result = obs.calculate_qparams(signed=signed)
+        self.quantization_levels = obs_result[1]
+        self.level_indices = obs_result[2]
+
+    r""" Quantizes fp Tensor to integer APoT representation.
+    Conversion is based on the quantization levels calculated by a specified APoT non-uniform observer.
+    The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
+    Args:
+        tensor2quantize: fp Tensor
+    Returns:
+        result: integer APoT representation of tensor2quantize
+    """
+    def quantize_APoT(self, tensor2quantize: Tensor):
+        result = torch.tensor([])
+        # map float_to_apot over tensor2quantize elements
+        result = tensor2quantize.apply_(lambda x: float_to_apot(x, self.quantization_levels, self.level_indices))
+
+        return result
 
     def dequantize(self) -> Tensor:
         raise NotImplementedError
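
For reference, a minimal standalone sketch (not part of the patch) of the nearest-level lookup that test_quantize_APoT_k2 hand-calculates: each float is matched to the closest entry in quantization_levels, and the corresponding entry of level_indices becomes the quantized value. The tensors are copied from the test's docstring; the helper name nearest_level_index is hypothetical, and the patch itself delegates this mapping to float_to_apot in apot_utils.

import torch

# quantization levels and level indices for b=4, k=2, alpha=1.0
# (copied from the APoT paper example quoted in test_quantize_APoT_k2)
quantization_levels = torch.tensor([0.0000, 0.0208, 0.0417, 0.0625, 0.0833, 0.1250,
                                    0.1667, 0.1875, 0.2500, 0.3333, 0.3750, 0.5000,
                                    0.6667, 0.6875, 0.7500, 1.0000])
level_indices = torch.tensor([0, 3, 12, 15, 2, 14, 8, 11, 10, 1, 13, 9, 4, 7, 6, 5])

def nearest_level_index(x: float) -> int:
    # hypothetical helper: return the level index assigned to the
    # quantization level closest to x
    nearest = torch.argmin(torch.abs(quantization_levels - x))
    return int(level_indices[nearest])

tensor2quantize = torch.tensor([0.0215, 0.1692, 0.385, 0.0391])
print([nearest_level_index(v.item()) for v in tensor2quantize])  # -> [3, 8, 13, 12]

This reproduces expected_qtensor in the test ([3, 8, 13, 12]) and makes explicit why, for example, 0.0215 quantizes to 3: its nearest level is 0.0208, whose position in quantization_levels maps to 3 in level_indices.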