[quant] Implement quantize APoT method

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79499

Approved by: https://github.com/dzdang, https://github.com/jerryzh168
This commit is contained in:
asl3 2022-06-21 19:29:25 -07:00 committed by PyTorch MergeBot
parent f89e640810
commit d6ec8398a9
3 changed files with 115 additions and 7 deletions

View File

@ -64,6 +64,9 @@ ignore_missing_imports = True
[mypy-torch.ao.quantization.experimental.quantizer] [mypy-torch.ao.quantization.experimental.quantizer]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-torch.ao.quantization.experimental.observer]
ignore_missing_imports = True
# #
# Files with various errors. Mostly real errors, possibly some false # Files with various errors. Mostly real errors, possibly some false
# positives as well. # positives as well.

View File

@ -1,14 +1,76 @@
# Owner(s): ["oncall: quantization"] # Owner(s): ["oncall: quantization"]
import torch import torch
from torch import quantize_per_tensor
from torch.ao.quantization.experimental.quantizer import APoTQuantizer from torch.ao.quantization.experimental.quantizer import APoTQuantizer
import unittest import unittest
import random
class TestQuantizer(unittest.TestCase): class TestQuantizer(unittest.TestCase):
def test_quantize_APoT(self): r""" Tests quantize_APoT result on random 1-dim tensor
t = torch.Tensor() and hardcoded values for b, k by comparing to uniform quantization
with self.assertRaises(NotImplementedError): (non-uniform quantization reduces to uniform for k = 1)
APoTQuantizer.quantize_APoT(t) quantized tensor (https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html)
* tensor2quantize: Tensor
* b: 4
* k: 1
"""
def test_quantize_APoT_rand_k1(self):
    """Compare APoT quantization (b=4, k=1) with uniform quantization.

    A random-length 1-dim float tensor is quantized with an
    ``APoTQuantizer`` and the result, cast to uint8, must equal the integer
    representation produced by ``quantize_per_tensor`` with scale 1.0 and
    zero point 0.
    """
    # random tensor length, 1 -> 20 inclusive
    num_elements = random.randint(1, 20)

    # random fp values scaled up from torch.rand's [0, 1) range by 1000
    tensor2quantize = 1000 * torch.rand(num_elements, dtype=torch.float)

    # positional arguments are b, k, max_val, signed
    apot_quantizer = APoTQuantizer(4, 1, torch.max(tensor2quantize), False)

    # APoT quantized result
    # NOTE(review): quantize_APoT maps over its argument with Tensor.apply_,
    # which writes in place, so tensor2quantize appears to already hold the
    # quantized values when the uniform reference below is computed -
    # confirm that this is the intended comparison.
    apot_result = apot_quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
    apot_as_uint8 = torch.tensor(apot_result).type(torch.uint8)

    # uniform quantization reference result
    uniform_result = quantize_per_tensor(
        input=tensor2quantize,
        scale=1.0,
        zero_point=0,
        dtype=torch.quint8,
    )
    uniform_as_uint8 = uniform_result.int_repr().data

    self.assertTrue(torch.equal(apot_as_uint8, uniform_as_uint8))
r""" Tests quantize_APoT for k != 1.
Tests quantize_APoT result on random 1-dim tensor and hardcoded values for
b=4, k=2 by comparing results to hand-calculated results from APoT paper
https://arxiv.org/pdf/1909.13144.pdf
* tensor2quantize: Tensor
* b: 4
* k: 2
"""
def test_quantize_APoT_k2(self):
    r"""Check quantize_APoT against hand-calculated indices for b=4, k=2.

    given b = 4, k = 2, alpha = 1.0, we know:
    (from APoT paper example: https://arxiv.org/pdf/1909.13144.pdf)

    quantization_levels = tensor([0.0000, 0.0208, 0.0417, 0.0625, 0.0833, 0.1250, 0.1667,
    0.1875, 0.2500, 0.3333, 0.3750, 0.5000, 0.6667, 0.6875, 0.7500, 1.0000])

    level_indices = tensor([ 0, 3, 12, 15, 2, 14, 8, 11, 10, 1, 13, 9, 4, 7, 6, 5])
    """
    # fixed fp inputs (not random) so the expected indices can be hand-checked
    tensor2quantize = torch.tensor([0.0215, 0.1692, 0.385, 0.0391])

    # expected values are the level_indices entries at the position of the
    # quantization level nearest to each fp value in tensor2quantize
    # e.g.
    # 0.0215 is nearest 0.0208 in quantization_levels -> 3 in level_indices
    expected_indices = torch.tensor([3, 8, 13, 12], dtype=torch.uint8)

    # positional arguments are b, k, max_val, signed
    apot_quantizer = APoTQuantizer(4, 2, 1.0, False)

    # APoT quantized result, cast to uint8 for the comparison
    apot_result = apot_quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
    actual_indices = torch.tensor(apot_result).type(torch.uint8)

    self.assertTrue(torch.equal(actual_indices, expected_indices))
def test_dequantize(self): def test_dequantize(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):

View File

@ -1,12 +1,55 @@
import torch
from torch import Tensor from torch import Tensor
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.apot_utils import float_to_apot
# class to store APoT quantizer # class to store APoT quantizer
# implements quantize and dequantize # implements quantize and dequantize
# and stores all quantization parameters # and stores all quantization parameters
class APoTQuantizer(): class APoTQuantizer():
@staticmethod b: int
def quantize_APoT(tensor2quantize: Tensor) -> Tensor: k: int
raise NotImplementedError n: int
signed: bool
quantization_levels: torch.Tensor
level_indices: torch.Tensor
def __init__(
        self,
        b,
        k,
        max_val,
        signed,
        dtype=torch.quint8) -> None:
    """Store the APoT parameters and derive the quantization tables.

    Args:
        b: value stored as ``self.b``; must be divisible by ``k``
        k: value stored as ``self.k``; must be truthy (non-zero)
        max_val: forwarded to ``APoTObserver`` as ``max_val``
        signed: stored as ``self.signed`` and forwarded to
            ``APoTObserver.calculate_qparams``
        dtype: accepted but not used by this constructor

    Raises:
        AssertionError: if ``k`` is falsy or ``b`` is not a multiple of ``k``
    """
    self.signed = signed

    # check for valid inputs of b, k
    assert k and k != 0
    assert b % k == 0

    self.b, self.k = b, k
    self.n = b // k

    # make observer, get quantization levels and level indices
    observer = APoTObserver(max_val=max_val, b=b, k=k)
    qparams = observer.calculate_qparams(signed=signed)

    # entries 1 and 2 of the observer result are the two tables kept here
    self.quantization_levels = qparams[1]
    self.level_indices = qparams[2]
r""" Quantizes fp Tensor to integer APoT representation.
Conversion is based on the calculated quantization levels from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
tensor2quantize: fp Tensor
Returns:
result: integer APoT representation of tensor2quantize
"""
def quantize_APoT(self, tensor2quantize: Tensor) -> Tensor:
    r"""Quantize a fp Tensor to its integer APoT representation.

    Every element is passed through ``float_to_apot`` together with this
    quantizer's ``quantization_levels`` and ``level_indices``.

    Note:
        The mapping is done with ``Tensor.apply_``, which replaces each
        element in place. ``tensor2quantize`` itself is therefore
        overwritten and the returned tensor is that same object; pass a
        clone if the original fp values are still needed.

    Args:
        tensor2quantize: fp Tensor

    Returns:
        result: ``tensor2quantize`` after each element has been replaced by
        the value returned from ``float_to_apot``
    """
    # look the tables up once instead of on every element
    levels = self.quantization_levels
    indices = self.level_indices

    # map float_to_apot over tensor2quantize elements (in place)
    return tensor2quantize.apply_(lambda x: float_to_apot(x, levels, indices))
def dequantize(self) -> Tensor: def dequantize(self) -> Tensor:
raise NotImplementedError raise NotImplementedError