mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[quant] Add quantizer class skeleton
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79936 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
2148e6b4a4
commit
f89e640810
3
mypy.ini
3
mypy.ini
|
|
@ -61,6 +61,9 @@ ignore_missing_imports = True
|
||||||
[mypy-torch.ao.quantization.experimental.apot_utils]
|
[mypy-torch.ao.quantization.experimental.apot_utils]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[mypy-torch.ao.quantization.experimental.quantizer]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
|
||||||
#
|
#
|
||||||
# Files with various errors. Mostly real errors, possibly some false
|
# Files with various errors. Mostly real errors, possibly some false
|
||||||
# positives as well.
|
# positives as well.
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,13 @@
|
||||||
# Owner(s): ["oncall: quantization"]
|
# Owner(s): ["oncall: quantization"]
|
||||||
|
|
||||||
import torch
|
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
|
||||||
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
class TestQuantizedTensor(unittest.TestCase):
|
class TestQuantizedTensor(unittest.TestCase):
|
||||||
def test_quantize_APoT(self):
|
def test_int_repr(self):
|
||||||
t = torch.Tensor()
|
quantizer = APoTQuantizer()
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
TensorAPoT.quantize_APoT(t)
|
quantizer.int_repr()
|
||||||
|
|
||||||
def test_dequantize(self):
|
|
||||||
with self.assertRaises(NotImplementedError):
|
|
||||||
TensorAPoT.dequantize(self)
|
|
||||||
|
|
||||||
def test_q_apot_alpha(self):
|
|
||||||
with self.assertRaises(NotImplementedError):
|
|
||||||
TensorAPoT.q_apot_alpha(self)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
22
test/quantization/core/experimental/test_quantizer.py
Normal file
22
test/quantization/core/experimental/test_quantizer.py
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
# Owner(s): ["oncall: quantization"]
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
class TestQuantizer(unittest.TestCase):
|
||||||
|
def test_quantize_APoT(self):
|
||||||
|
t = torch.Tensor()
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
APoTQuantizer.quantize_APoT(t)
|
||||||
|
|
||||||
|
def test_dequantize(self):
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
APoTQuantizer.dequantize(self)
|
||||||
|
|
||||||
|
def test_q_apot_alpha(self):
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
APoTQuantizer.q_apot_alpha(self)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
|
|
@ -1,14 +1,12 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
|
||||||
|
|
||||||
# class to store APoT quantized tensor
|
# class to store APoT quantized tensor
|
||||||
class TensorAPoT(torch.Tensor):
|
class TensorAPoT(torch.Tensor):
|
||||||
@staticmethod
|
quantizer: APoTQuantizer
|
||||||
def quantize_APoT(tensor2quantize: Tensor) -> Tensor:
|
|
||||||
|
def __init__(self, quantizer):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def dequantize(self) -> Tensor:
|
def int_repr(self):
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def q_apot_alpha(self) -> float:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
||||||
15
torch/ao/quantization/experimental/quantizer.py
Normal file
15
torch/ao/quantization/experimental/quantizer.py
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
# class to store APoT quantizer
|
||||||
|
# implements quantize and dequantize
|
||||||
|
# and stores all quantization parameters
|
||||||
|
class APoTQuantizer():
|
||||||
|
@staticmethod
|
||||||
|
def quantize_APoT(tensor2quantize: Tensor) -> Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def dequantize(self) -> Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def q_apot_alpha(self) -> float:
|
||||||
|
raise NotImplementedError
|
||||||
Loading…
Reference in New Issue
Block a user