[quant] Implement quantize APoT method
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79499 Approved by: https://github.com/dzdang, https://github.com/jerryzh168
This commit is contained in:
parent f89e640810
commit d6ec8398a9
mypy.ini (+3)
@@ -64,6 +64,9 @@ ignore_missing_imports = True
[mypy-torch.ao.quantization.experimental.quantizer]
ignore_missing_imports = True

[mypy-torch.ao.quantization.experimental.observer]
ignore_missing_imports = True

#
# Files with various errors. Mostly real errors, possibly some false
# positives as well.
@@ -1,14 +1,76 @@
# Owner(s): ["oncall: quantization"]

import torch
from torch import quantize_per_tensor
from torch.ao.quantization.experimental.quantizer import APoTQuantizer
import unittest
import random

class TestQuantizer(unittest.TestCase):
    def test_quantize_APoT(self):
        t = torch.Tensor()
        with self.assertRaises(NotImplementedError):
            APoTQuantizer.quantize_APoT(t)

    r""" Tests quantize_APoT result on a random 1-dim tensor
    and hardcoded values for b, k by comparing to a uniformly quantized tensor
    (https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html),
    since non-uniform quantization reduces to uniform for k = 1.
    * tensor2quantize: Tensor
    * b: 4
    * k: 1
    """
    def test_quantize_APoT_rand_k1(self):
        # generate random size of tensor2quantize between 1 -> 20
        size = random.randint(1, 20)

        # generate tensor with random fp values between 0 -> 1000
        tensor2quantize = 1000 * torch.rand(size, dtype=torch.float)

        quantizer = APoTQuantizer(4, 1, torch.max(tensor2quantize), False)

        # get apot quantized tensor result
        qtensor = quantizer.quantize_APoT(tensor2quantize=tensor2quantize)

        # get uniform quantization quantized tensor result
        uniform_quantized = quantize_per_tensor(input=tensor2quantize, scale=1.0, zero_point=0, dtype=torch.quint8).int_repr()

        qtensor_data = torch.tensor(qtensor).type(torch.uint8)
        uniform_quantized_tensor = uniform_quantized.data

        self.assertTrue(torch.equal(qtensor_data, uniform_quantized_tensor))

    r""" Tests quantize_APoT for k != 1.
    Tests quantize_APoT result on a random 1-dim tensor and hardcoded values for
    b=4, k=2 by comparing results to hand-calculated results from the APoT paper
    https://arxiv.org/pdf/1909.13144.pdf
    * tensor2quantize: Tensor
    * b: 4
    * k: 2
    """
    def test_quantize_APoT_k2(self):
        r"""
        given b = 4, k = 2, alpha = 1.0, we know:
        (from APoT paper example: https://arxiv.org/pdf/1909.13144.pdf)

        quantization_levels = tensor([0.0000, 0.0208, 0.0417, 0.0625, 0.0833, 0.1250, 0.1667,
        0.1875, 0.2500, 0.3333, 0.3750, 0.5000, 0.6667, 0.6875, 0.7500, 1.0000])

        level_indices = tensor([ 0, 3, 12, 15, 2, 14, 8, 11, 10, 1, 13, 9, 4, 7, 6, 5])
        """

        # generate tensor with hardcoded fp values
        tensor2quantize = torch.tensor([0.0215, 0.1692, 0.385, 0.0391])

        quantizer = APoTQuantizer(4, 2, 1.0, False)

        # get apot quantized tensor result
        qtensor = quantizer.quantize_APoT(tensor2quantize=tensor2quantize)
        qtensor_data = torch.tensor(qtensor).type(torch.uint8)

        # expected qtensor values calculated by mapping each fp value
        # in tensor2quantize to its nearest quantization level and
        # looking up the corresponding entry in level_indices, e.g.
        # 0.0215 in tensor2quantize is nearest 0.0208 in quantization_levels -> 3 in level_indices
        expected_qtensor = torch.tensor([3, 8, 13, 12], dtype=torch.uint8)

        self.assertTrue(torch.equal(qtensor_data, expected_qtensor))

    def test_dequantize(self):
        with self.assertRaises(NotImplementedError):
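
The b = 4, k = 2 quantization_levels quoted in the docstring above follow from the additive-powers-of-two construction in the APoT paper, and the k = 1 case of the same construction yields 16 evenly spaced levels, which is why test_quantize_APoT_rand_k1 can compare against uniform quantization. Below is a minimal standalone sketch of that construction; the apot_levels helper is hypothetical and for illustration only — in this PR the levels are produced by APoTObserver.calculate_qparams.

import itertools

# Sketch of the APoT level construction (https://arxiv.org/pdf/1909.13144.pdf).
# Illustrative only -- not the code added by this PR.
def apot_levels(b=4, k=2, alpha=1.0):
    n = b // k  # number of additive power-of-two terms
    # term i draws from {0} and the powers 2^-(i), 2^-(i+n), 2^-(i+2n), ...
    term_sets = [[0.0] + [2.0 ** -(i + j * n) for j in range(2 ** k - 1)]
                 for i in range(n)]
    # form every sum of one value per term, then rescale so the max level is alpha
    sums = {sum(combo) for combo in itertools.product(*term_sets)}
    gamma = alpha / max(sums)
    return sorted(gamma * s for s in sums)

print([round(q, 4) for q in apot_levels(4, 2)])
# [0.0, 0.0208, 0.0417, 0.0625, 0.0833, 0.125, 0.1667, 0.1875,
#  0.25, 0.3333, 0.375, 0.5, 0.6667, 0.6875, 0.75, 1.0]
print(apot_levels(4, 1))  # 16 evenly spaced levels, i.e. uniform quantization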
@@ -1,12 +1,55 @@
import torch
from torch import Tensor
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.apot_utils import float_to_apot

# class to store APoT quantizer
# implements quantize and dequantize
# and stores all quantization parameters
class APoTQuantizer():
    @staticmethod
    def quantize_APoT(tensor2quantize: Tensor) -> Tensor:
        raise NotImplementedError
    b: int
    k: int
    n: int
    signed: bool
    quantization_levels: torch.Tensor
    level_indices: torch.Tensor

    def __init__(
        self,
        b,
        k,
        max_val,
        signed,
        dtype=torch.quint8) -> None:
        self.signed = signed

        # check for valid inputs of b, k
        assert(k and k != 0)
        assert(b % k == 0)
        self.b = b
        self.k = k
        self.n = b // k

        # make observer, get quantization levels and level indices
        obs = APoTObserver(max_val=max_val, b=b, k=k)
        obs_result = obs.calculate_qparams(signed=signed)
        self.quantization_levels = obs_result[1]
        self.level_indices = obs_result[2]

    r""" Quantizes fp Tensor to integer APoT representation.
    Conversion is based on the quantization levels calculated by a specified APoT non-uniform observer.
    The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
    Args:
        tensor2quantize: fp Tensor
    Returns:
        result: integer APoT representation of tensor2quantize
    """
    def quantize_APoT(self, tensor2quantize: Tensor):
        # map float_to_apot over tensor2quantize elements
        # (note: Tensor.apply_ modifies tensor2quantize in place and returns it)
        result = tensor2quantize.apply_(lambda x: float_to_apot(x, self.quantization_levels, self.level_indices))

        return result

    def dequantize(self) -> Tensor:
        raise NotImplementedError
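
quantize_APoT delegates the per-element mapping to float_to_apot from torch.ao.quantization.experimental.apot_utils, which this diff does not show. Judging from the k = 2 test above (0.0215 -> nearest level 0.0208 -> code 3), it plausibly maps each float to the level_indices entry of its nearest quantization level. A minimal sketch of that assumed behavior, checked against the test's expected values; float_to_apot_sketch is illustrative, not the actual apot_utils implementation:

import torch

# Assumed behavior of float_to_apot: snap x to its nearest quantization
# level, then return the integer code stored at that position in
# level_indices. Illustrative sketch only.
def float_to_apot_sketch(x, quantization_levels, level_indices):
    nearest = torch.argmin(torch.abs(quantization_levels - x))
    return int(level_indices[nearest])

levels = torch.tensor([0.0000, 0.0208, 0.0417, 0.0625, 0.0833, 0.1250,
                       0.1667, 0.1875, 0.2500, 0.3333, 0.3750, 0.5000,
                       0.6667, 0.6875, 0.7500, 1.0000])
indices = torch.tensor([0, 3, 12, 15, 2, 14, 8, 11, 10, 1, 13, 9, 4, 7, 6, 5])
print([float_to_apot_sketch(x, levels, indices)
       for x in [0.0215, 0.1692, 0.385, 0.0391]])  # [3, 8, 13, 12]

Note that quantize_APoT uses Tensor.apply_, which mutates tensor2quantize in place and only works on CPU tensors, so callers should pass a tensor they are free to overwrite.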