From 368018530ef47153c1ca46ab97e5457511359eb8 Mon Sep 17 00:00:00 2001
From: asl3
Date: Mon, 18 Jul 2022 16:09:50 -0700
Subject: [PATCH] [quant] Implement forward and backward autograd functions for
 fake quantize (#81438)

### Summary:
This PR implements custom autograd functions for forward and backward to be used in APoT fake quantization. The implementation follows this doc about custom autograd functions: https://pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html
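
For intuition, the forward/backward pair follows the straight-through-estimator pattern: the forward pass quantizes and immediately dequantizes the input while recording a mask of the elements that fell inside the clipping range `[-alpha, alpha]`, and the backward pass lets gradients flow only through those elements. The snippet below is a minimal, self-contained sketch of that pattern and is not part of this patch; `ste_round_function` and `clip` are illustrative names, and a plain clamp-and-round stands in for the APoT-specific `quantize_APoT`/`dequantize_APoT` helpers.

```python
import torch

class ste_round_function(torch.autograd.Function):
    """Illustrative sketch only: round in forward, pass gradients
    straight through for in-range elements in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, clip: float) -> torch.Tensor:
        # mask marks which elements were inside the clipping range [-clip, clip]
        mask = (x.abs() <= clip).to(x.dtype)
        ctx.save_for_backward(mask)
        # quantize-then-dequantize stand-in: clamp and round
        return torch.round(x.clamp(-clip, clip))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        mask, = ctx.saved_tensors
        # gradients pass through only where the input was in range;
        # the clip bound is a plain float, so it receives no gradient (None)
        return grad_output * mask, None

x = torch.randn(8, dtype=torch.double, requires_grad=True)
y = ste_round_function.apply(x, 2.0)
y.sum().backward()  # x.grad is 1 inside [-2, 2] and 0 outside
```

In the patch itself, `quantize_APoT`/`dequantize_APoT` from `quantizer.py` play the role of the rounding step, and the mask is computed against the observer-derived `alpha` clipping bound.
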
+ """ + def test_backward(self): + input = torch.randn(20, dtype=torch.double, requires_grad=True) + + observer = APoTObserver(b=4, k=2) + observer(input) + alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False) + + test = gradcheck(fake_quantize_function.apply, (input, alpha, gamma, quantization_levels, level_indices), atol=1e-4) + if __name__ == '__main__': unittest.main() diff --git a/torch/ao/quantization/experimental/fake_quantize.py b/torch/ao/quantization/experimental/fake_quantize.py index 37491b1f87a..c229859addb 100644 --- a/torch/ao/quantization/experimental/fake_quantize.py +++ b/torch/ao/quantization/experimental/fake_quantize.py @@ -1,8 +1,8 @@ import torch from torch import Tensor from torch.ao.quantization.experimental.observer import APoTObserver -from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT from torch.ao.quantization.fake_quantize import FakeQuantizeBase +from torch.ao.quantization.experimental.fake_quantize_function import fake_quantize_function class APoTFakeQuantize(FakeQuantizeBase): alpha: Tensor @@ -28,7 +28,6 @@ class APoTFakeQuantize(FakeQuantizeBase): and self.quantization_levels is not None and self.level_indices is not None), "Must set qparams for fake quant" - X = quantize_APoT(X, self.alpha, self.gamma, self.quantization_levels, self.level_indices) - X = dequantize_APoT(X) + X = fake_quantize_function.apply(X, self.alpha, self.gamma, self.quantization_levels, self.level_indices) return X diff --git a/torch/ao/quantization/experimental/fake_quantize_function.py b/torch/ao/quantization/experimental/fake_quantize_function.py new file mode 100644 index 00000000000..cac01fd8c00 --- /dev/null +++ b/torch/ao/quantization/experimental/fake_quantize_function.py @@ -0,0 +1,27 @@ +import torch +from torch import Tensor +from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT + +class fake_quantize_function(torch.autograd.Function): + @staticmethod + def forward(ctx, # type: ignore[override] + x: Tensor, + alpha: Tensor, + gamma: Tensor, + quantization_levels: Tensor, + level_indices: Tensor) -> Tensor: + quantized_result = quantize_APoT(x, alpha, gamma, quantization_levels, level_indices) + + # calculate mask tensor + mask = x.detach().apply_(lambda x: (x <= alpha and x >= -alpha)) + + result = dequantize_APoT(quantized_result) + + ctx.save_for_backward(mask) + + return result + + @staticmethod + def backward(ctx, grad_output: Tensor) -> Tensor: # type: ignore[override] + mask = ctx.saved_tensors + return grad_output * mask diff --git a/torch/ao/quantization/experimental/quantizer.py b/torch/ao/quantization/experimental/quantizer.py index 5decd3d1811..46d792db814 100644 --- a/torch/ao/quantization/experimental/quantizer.py +++ b/torch/ao/quantization/experimental/quantizer.py @@ -1,5 +1,6 @@ import torch from torch import Tensor +import numpy as np from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float # class to store APoT quantizer and @@ -33,10 +34,13 @@ class APoTQuantizer(): result = torch.tensor([]) # map float_to_apot over tensor2quantize elements - tensor2quantize = tensor2quantize.apply_(lambda x: float_to_apot(x, - self.quantization_levels, - self.level_indices, - self.alpha)) + tensor2quantize = tensor2quantize.detach().apply_(lambda x: float_to_apot(x, + self.quantization_levels, + self.level_indices, + self.alpha)) + + # convert to APoT int representation for dtype + tensor2quantize = tensor2quantize.int() from 
         from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
@@ -56,7 +60,12 @@ class APoTQuantizer():
         apot_tensor_data = apot_tensor.data
 
         # map apot_to_float over tensor2quantize elements
-        result = apot_tensor_data.apply_(lambda x: float(apot_to_float(x, self.quantization_levels, self.level_indices)))
+        result_temp = np.empty(apot_tensor_data.size())
+        for ele in apot_tensor_data:
+            new_ele = apot_to_float(ele, self.quantization_levels, self.level_indices)
+            np.append(result_temp, new_ele)
+
+        result = torch.from_numpy(result_temp).int()
 
         return result
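
For reference, the new function can also be exercised outside the unittest harness, mirroring the added `test_backward` case. The sketch below assumes a PyTorch build that contains this patch, so the experimental `APoTObserver` and `fake_quantize_function` modules are importable; all names are taken from the diff above.

```python
import torch
from torch.autograd import gradcheck
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.fake_quantize_function import fake_quantize_function

# double-precision input, as gradcheck expects
x = torch.randn(20, dtype=torch.double, requires_grad=True)

# derive APoT qparams from observed statistics (b=4, k=2, as in the test)
observer = APoTObserver(b=4, k=2)
observer(x)
alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False)

# fake-quantize through the custom autograd function and numerically check its gradient
y = fake_quantize_function.apply(x, alpha, gamma, quantization_levels, level_indices)
passed = gradcheck(fake_quantize_function.apply,
                   (x, alpha, gamma, quantization_levels, level_indices),
                   atol=1e-4)
```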