mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[quant] Add quantizer class skeleton
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79936 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
2148e6b4a4
commit
f89e640810
3
mypy.ini
3
mypy.ini
|
|
@ -61,6 +61,9 @@ ignore_missing_imports = True
|
|||
[mypy-torch.ao.quantization.experimental.apot_utils]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-torch.ao.quantization.experimental.quantizer]
|
||||
ignore_missing_imports = True
|
||||
|
||||
#
|
||||
# Files with various errors. Mostly real errors, possibly some false
|
||||
# positives as well.
|
||||
|
|
|
|||
|
|
@ -1,22 +1,13 @@
|
|||
# Owner(s): ["oncall: quantization"]
|
||||
|
||||
import torch
|
||||
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
|
||||
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
|
||||
import unittest
|
||||
|
||||
class TestQuantizedTensor(unittest.TestCase):
|
||||
def test_quantize_APoT(self):
|
||||
t = torch.Tensor()
|
||||
def test_int_repr(self):
|
||||
quantizer = APoTQuantizer()
|
||||
with self.assertRaises(NotImplementedError):
|
||||
TensorAPoT.quantize_APoT(t)
|
||||
|
||||
def test_dequantize(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
TensorAPoT.dequantize(self)
|
||||
|
||||
def test_q_apot_alpha(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
TensorAPoT.q_apot_alpha(self)
|
||||
quantizer.int_repr()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
22
test/quantization/core/experimental/test_quantizer.py
Normal file
22
test/quantization/core/experimental/test_quantizer.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Owner(s): ["oncall: quantization"]
|
||||
|
||||
import torch
|
||||
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
|
||||
import unittest
|
||||
|
||||
class TestQuantizer(unittest.TestCase):
|
||||
def test_quantize_APoT(self):
|
||||
t = torch.Tensor()
|
||||
with self.assertRaises(NotImplementedError):
|
||||
APoTQuantizer.quantize_APoT(t)
|
||||
|
||||
def test_dequantize(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
APoTQuantizer.dequantize(self)
|
||||
|
||||
def test_q_apot_alpha(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
APoTQuantizer.q_apot_alpha(self)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,14 +1,12 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
|
||||
|
||||
# class to store APoT quantized tensor
|
||||
class TensorAPoT(torch.Tensor):
|
||||
@staticmethod
|
||||
def quantize_APoT(tensor2quantize: Tensor) -> Tensor:
|
||||
quantizer: APoTQuantizer
|
||||
|
||||
def __init__(self, quantizer):
|
||||
raise NotImplementedError
|
||||
|
||||
def dequantize(self) -> Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def q_apot_alpha(self) -> float:
|
||||
def int_repr(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
15
torch/ao/quantization/experimental/quantizer.py
Normal file
15
torch/ao/quantization/experimental/quantizer.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from torch import Tensor
|
||||
|
||||
# class to store APoT quantizer
|
||||
# implements quantize and dequantize
|
||||
# and stores all quantization parameters
|
||||
class APoTQuantizer():
|
||||
@staticmethod
|
||||
def quantize_APoT(tensor2quantize: Tensor) -> Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def dequantize(self) -> Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def q_apot_alpha(self) -> float:
|
||||
raise NotImplementedError
|
||||
Loading…
Reference in New Issue
Block a user