[quant] Add quantizer class skeleton

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79936

Approved by: https://github.com/jerryzh168
This commit is contained in:
asl3 2022-06-21 19:29:25 -07:00 committed by PyTorch MergeBot
parent 2148e6b4a4
commit f89e640810
5 changed files with 49 additions and 20 deletions

View File

@ -61,6 +61,9 @@ ignore_missing_imports = True
[mypy-torch.ao.quantization.experimental.apot_utils]
ignore_missing_imports = True

[mypy-torch.ao.quantization.experimental.quantizer]
ignore_missing_imports = True

#
# Files with various errors. Mostly real errors, possibly some false
# positives as well.

View File

@ -1,22 +1,13 @@
 # Owner(s): ["oncall: quantization"]
-import torch
-from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
+from torch.ao.quantization.experimental.quantizer import APoTQuantizer
 import unittest
 class TestQuantizedTensor(unittest.TestCase):
-    def test_quantize_APoT(self):
-        t = torch.Tensor()
-        with self.assertRaises(NotImplementedError):
-            TensorAPoT.quantize_APoT(t)
-    def test_dequantize(self):
-        with self.assertRaises(NotImplementedError):
-            TensorAPoT.dequantize(self)
-    def test_q_apot_alpha(self):
-        with self.assertRaises(NotImplementedError):
-            TensorAPoT.q_apot_alpha(self)
+    def test_int_repr(self):
+        quantizer = APoTQuantizer()
+        with self.assertRaises(NotImplementedError):
+            quantizer.int_repr()
 if __name__ == '__main__':
     unittest.main()

View File

@ -0,0 +1,22 @@
# Owner(s): ["oncall: quantization"]
import torch
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
import unittest
class TestQuantizer(unittest.TestCase):
    """Unit tests for the APoTQuantizer skeleton.

    Every quantizer operation is still a stub, so each test simply
    verifies that invoking it raises NotImplementedError.
    """

    def test_quantize_APoT(self):
        # An empty tensor is enough: the stub raises before touching it.
        tensor2quantize = torch.Tensor()
        self.assertRaises(
            NotImplementedError, APoTQuantizer.quantize_APoT, tensor2quantize
        )

    def test_dequantize(self):
        # Mirrors the original call shape: the unbound method is invoked
        # with the test case standing in for the receiver.
        self.assertRaises(NotImplementedError, APoTQuantizer.dequantize, self)

    def test_q_apot_alpha(self):
        self.assertRaises(NotImplementedError, APoTQuantizer.q_apot_alpha, self)
# Standard unittest entry point so this file can be run directly.
if __name__ == '__main__':
    unittest.main()

View File

@ -1,14 +1,12 @@
 import torch
-from torch import Tensor
+from torch.ao.quantization.experimental.quantizer import APoTQuantizer
 # class to store APoT quantized tensor
 class TensorAPoT(torch.Tensor):
-    @staticmethod
-    def quantize_APoT(tensor2quantize: Tensor) -> Tensor:
+    quantizer: APoTQuantizer
+
+    def __init__(self, quantizer):
         raise NotImplementedError
-    def dequantize(self) -> Tensor:
-        raise NotImplementedError
-    def q_apot_alpha(self) -> float:
+
+    def int_repr(self):
         raise NotImplementedError

View File

@ -0,0 +1,15 @@
from torch import Tensor
# class to store APoT quantizer
# implements quantize and dequantize
# and stores all quantization parameters
class APoTQuantizer:
    """Skeleton for the additive-powers-of-two (APoT) quantizer.

    Holds the quantize/dequantize operations and (eventually) all
    quantization parameters.  Every operation below is a stub that
    raises ``NotImplementedError`` until the real logic lands.
    """

    @staticmethod
    def quantize_APoT(tensor2quantize: Tensor) -> Tensor:
        """Quantize ``tensor2quantize`` into APoT representation (stub)."""
        raise NotImplementedError

    def dequantize(self) -> Tensor:
        """Recover the floating-point tensor from APoT form (stub)."""
        raise NotImplementedError

    def q_apot_alpha(self) -> float:
        """Return the APoT clipping level alpha (stub)."""
        raise NotImplementedError