[quant] Implement forward and backward autograd functions for fake quantize (#81438)
### Summary:
This PR implements custom autograd forward and backward functions to be used in APoT fake quantization. The implementation follows the tutorial on custom autograd functions: https://pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html

### Test Plan:
Run tests with: `python test/quantization/core/experimental/test_fake_quantize.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81438
Approved by: https://github.com/jerryzh168
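For readers unfamiliar with the pattern the PR follows, here is a minimal, self-contained sketch of a custom `torch.autograd.Function` with a straight-through-style backward. It is illustrative only: the class name `ToyFakeQuantize`, the clamp-and-round scheme, and the `alpha` handling are stand-ins, not the PR's APoT code (which appears in the diff below).

```python
import torch
from torch import Tensor

class ToyFakeQuantize(torch.autograd.Function):
    """Illustrative stand-in for a fake-quantize autograd Function (not the PR's code)."""

    @staticmethod
    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
        # Remember which inputs were inside the quantization range; backward needs this.
        mask = x.abs() <= alpha
        ctx.save_for_backward(mask)
        # "Fake quantize": clamp to [-alpha, alpha], snap to a coarse grid, stay in float.
        return torch.round(x.clamp(-alpha, alpha) * 4) / 4

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # Straight-through estimator: pass gradients only where the input was in range.
        (mask,) = ctx.saved_tensors
        # One return value per forward() input; alpha is a plain float, so it gets None.
        return grad_output * mask, None

x = torch.randn(8, requires_grad=True)
y = ToyFakeQuantize.apply(x, 1.0)   # custom Functions are invoked through .apply(...)
y.sum().backward()                  # runs ToyFakeQuantize.backward
```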
This commit is contained in:
parent 4aac42cc98
commit 368018530e
mypy.ini (+3)

@@ -70,6 +70,9 @@ ignore_missing_imports = True
 [mypy-torch.ao.quantization.experimental.APoT_tensor]
 ignore_missing_imports = True
 
+[mypy-torch.ao.quantization.experimental.fake_quantize_function]
+ignore_missing_imports = True
+
 #
 # Files with various errors. Mostly real errors, possibly some false
 # positives as well.
test/quantization/core/experimental/test_fake_quantize.py

@@ -5,6 +5,10 @@ import unittest
 from torch.ao.quantization.experimental.observer import APoTObserver
 from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT
 from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize
+from torch.ao.quantization.experimental.fake_quantize_function import fake_quantize_function
+forward_helper = fake_quantize_function.forward
+backward = fake_quantize_function.backward
+from torch.autograd import gradcheck
 
 class TestFakeQuantize(unittest.TestCase):
     r""" Tests fake quantize calculate_qparams() method

@@ -72,5 +76,17 @@ class TestFakeQuantize(unittest.TestCase):
         with self.assertRaises(Exception):
             apot_fake.forward(torch.clone(X), False)
 
+    r""" Tests fake quantize helper backward() method
+    using torch.autograd.gradcheck function.
+    """
+    def test_backward(self):
+        input = torch.randn(20, dtype=torch.double, requires_grad=True)
+
+        observer = APoTObserver(b=4, k=2)
+        observer(input)
+        alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False)
+
+        test = gradcheck(fake_quantize_function.apply, (input, alpha, gamma, quantization_levels, level_indices), atol=1e-4)
+
 if __name__ == '__main__':
     unittest.main()
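A side note on the `gradcheck` call in the test above: `gradcheck` builds a finite-difference Jacobian and compares it with the one produced by the registered backward, which is why the test creates its input with `dtype=torch.double` (double precision keeps the numerical estimate accurate) and passes an explicit `atol`. A small, self-contained usage sketch with a generic module (not the APoT code):

```python
import torch
from torch.autograd import gradcheck

# gradcheck compares analytical gradients (from autograd) against a
# finite-difference estimate; double precision keeps that estimate stable.
lin = torch.nn.Linear(3, 2).double()
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
assert gradcheck(lin, (x,), atol=1e-4)  # returns True, or raises with a mismatch report
```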
torch/ao/quantization/experimental/fake_quantize.py

@@ -1,8 +1,8 @@
 import torch
 from torch import Tensor
 from torch.ao.quantization.experimental.observer import APoTObserver
-from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT
 from torch.ao.quantization.fake_quantize import FakeQuantizeBase
+from torch.ao.quantization.experimental.fake_quantize_function import fake_quantize_function
 
 class APoTFakeQuantize(FakeQuantizeBase):
     alpha: Tensor

@@ -28,7 +28,6 @@ class APoTFakeQuantize(FakeQuantizeBase):
                     and self.quantization_levels is not None
                     and self.level_indices is not None), "Must set qparams for fake quant"
 
-            X = quantize_APoT(X, self.alpha, self.gamma, self.quantization_levels, self.level_indices)
-            X = dequantize_APoT(X)
+            X = fake_quantize_function.apply(X, self.alpha, self.gamma, self.quantization_levels, self.level_indices)
 
         return X
torch/ao/quantization/experimental/fake_quantize_function.py (new file, 27 lines)

@@ -0,0 +1,27 @@
+import torch
+from torch import Tensor
+from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT
+
+class fake_quantize_function(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx,  # type: ignore[override]
+                x: Tensor,
+                alpha: Tensor,
+                gamma: Tensor,
+                quantization_levels: Tensor,
+                level_indices: Tensor) -> Tensor:
+        quantized_result = quantize_APoT(x, alpha, gamma, quantization_levels, level_indices)
+
+        # calculate mask tensor
+        mask = x.detach().apply_(lambda x: (x <= alpha and x >= -alpha))
+
+        result = dequantize_APoT(quantized_result)
+
+        ctx.save_for_backward(mask)
+
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor) -> Tensor:  # type: ignore[override]
+        mask = ctx.saved_tensors
+        return grad_output * mask
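One detail of the `ctx` API worth spelling out for readers of the new file above: `ctx.save_for_backward` stores its arguments as a tuple, and `ctx.saved_tensors` hands back a tuple, so the usual idiom is to unpack it in `backward`. A hedged sketch of that pattern in isolation (generic names, not a change to the PR's file):

```python
import torch
from torch import Tensor

class MaskedIdentity(torch.autograd.Function):
    """Passes x through unchanged, but masks gradients on the way back."""

    @staticmethod
    def forward(ctx, x: Tensor, mask: Tensor) -> Tensor:
        ctx.save_for_backward(mask)   # stored as a tuple, even for a single tensor
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        mask, = ctx.saved_tensors     # unpack the one-element tuple back into a Tensor
        return grad_output * mask, None
```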
torch/ao/quantization/experimental/quantizer.py

@@ -1,5 +1,6 @@
 import torch
 from torch import Tensor
+import numpy as np
 from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
 
 # class to store APoT quantizer and

@@ -33,10 +34,13 @@ class APoTQuantizer():
         result = torch.tensor([])
 
         # map float_to_apot over tensor2quantize elements
-        tensor2quantize = tensor2quantize.apply_(lambda x: float_to_apot(x,
+        tensor2quantize = tensor2quantize.detach().apply_(lambda x: float_to_apot(x,
                                                  self.quantization_levels,
                                                  self.level_indices,
                                                  self.alpha))
 
+        # convert to APoT int representation for dtype
+        tensor2quantize = tensor2quantize.int()
+
         from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
 

@@ -56,7 +60,12 @@ class APoTQuantizer():
         apot_tensor_data = apot_tensor.data
 
         # map apot_to_float over tensor2quantize elements
-        result = apot_tensor_data.apply_(lambda x: float(apot_to_float(x, self.quantization_levels, self.level_indices)))
+        result_temp = np.empty(apot_tensor_data.size())
+        for ele in apot_tensor_data:
+            new_ele = apot_to_float(ele, self.quantization_levels, self.level_indices)
+            np.append(result_temp, new_ele)
+
+        result = torch.from_numpy(result_temp).int()
 
         return result
 
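On the `.detach()` that this diff adds before `.apply_()`: `Tensor.apply_` is an in-place, CPU-only operation that applies a Python callable elementwise, and PyTorch rejects it on tensors that require grad; detaching first gives a view with grad tracking disabled so the callable can run. A small sketch of the distinction (illustrative values, not the PR's code):

```python
import torch

t = torch.randn(4, requires_grad=True)

# t.apply_(lambda v: v * 2)   # rejected: apply_ is in-place and t requires grad
d = t.detach()                # same storage as t, but autograd no longer tracks it
d.apply_(lambda v: v * 2)     # CPU-only, elementwise Python callable
print(t)                      # t's data was modified in place through the detached view
```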