pytorch/torch/quantization/observer.py
Raghuraman Krishnamoorthi 1c5e48bbd0 Observer returns original tensor for post training quantization (#24196)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24196

The observer now returns its input unchanged during post-training quantization. This unifies observer semantics for QAT and PTQ.
ghstack-source-id: 88140887

Differential Revision: D16768277

fbshipit-source-id: fae7c94e3dc0eeda363e9982b3865a15113e11bd
2019-08-13 14:01:37 -07:00

from __future__ import absolute_import, division, print_function, unicode_literals

from functools import partial

import torch
import torch.nn as nn


class Observer(nn.Module):
r"""Default Observer Module
A default implementation of the observer module, only works for
`per_tensor_affine` quantization scheme.
The module will record the running average of max and min value of the
observed Tensor and calulate_qparams will calculate the scale and zero_point
Other types of Observers should follow the same API, it can take arbitrary
number of keyward arguments. In forward, it will update the statistics of
the observed Tensor. And it should provide a `calculate_qparam` function
that computes the quantization parameters given the collected statistics.
TODO: Maybe add an abstract Observer class that enforces these rules?
"""

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine):
        super(Observer, self).__init__()
        self.dtype = dtype
        self.qscheme = qscheme
        assert self.qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric), \
            'Default Observer only works for per_tensor_affine and ' \
            'per_tensor_symmetric quantization schemes'
        assert self.dtype in (torch.qint8, torch.quint8), \
            'Default Observer only works for qint8 and quint8 data types'
        # Running min/max of all observed tensors; populated on first forward.
        self.min_val = None
        self.max_val = None

    def forward(self, x):
        # Update the running min/max with this tensor's extrema; the input is
        # returned unchanged so the observer is transparent in the forward pass.
        if self.min_val is None or self.max_val is None:
            self.min_val = torch.min(x)
            self.max_val = torch.max(x)
        else:
            self.min_val = torch.min(torch.min(x), self.min_val)
            self.max_val = torch.max(torch.max(x), self.max_val)
        return x

    def calculate_qparams(self):
        if self.max_val is None or self.min_val is None:
            raise ValueError('must run observer before calling calculate_qparams!')
        if self.dtype == torch.qint8:
            qmin, qmax = -128, 127
        else:
            qmin, qmax = 0, 255
        n_levels = 255.0  # number of steps across the 8-bit quantized range
        max_val, min_val = self.max_val.item(), self.min_val.item()
        if max_val == min_val:
            # Degenerate range: fall back to trivial parameters.
            scale = 1.0
            zero_point = 0
        else:
            if self.qscheme == torch.per_tensor_symmetric:
                # Symmetric: scale by the larger magnitude of the observed
                # range; the zero_point is fixed by the dtype.
                max_val = max(-min_val, max_val)
                scale = max_val / 127.0
                scale = max(scale, torch.finfo(torch.float32).eps)
                zero_point = 0 if self.dtype == torch.qint8 else 128
            else:
                # Affine: spread [min_val, max_val] over the quantized levels
                # and clamp the zero_point into [qmin, qmax].
                scale = (max_val - min_val) / n_levels
                scale = max(scale, torch.finfo(torch.float32).eps)
                zero_point = qmin - round(min_val / scale)
                zero_point = max(qmin, zero_point)
                zero_point = min(qmax, zero_point)
        return torch.tensor([scale, zero_point])
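

# Worked example of the affine branch (illustrative values, not from the
# original file): with dtype=torch.quint8 and an observed range of
# min_val = -0.5, max_val = 2.05:
#   scale      = (2.05 - (-0.5)) / 255 = 0.01
#   zero_point = 0 - round(-0.5 / 0.01) = 50   (already within [0, 255])
# so the float value 0.0 maps to the quantized value 50.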


def observer(observer_cls, **kwargs):
    # Defer construction so each use site gets a fresh Observer instance.
    return partial(observer_cls, **kwargs)


def default_observer(**kwargs):
    return observer(Observer, **kwargs)


def default_weight_observer(**kwargs):
    # Weights are observed with a signed, symmetric scheme by default.
    kwargs.setdefault('dtype', torch.qint8)
    kwargs.setdefault('qscheme', torch.per_tensor_symmetric)
    return observer(Observer, **kwargs)
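

# Usage sketch (illustrative, not part of the original file): build the
# default observer, feed it a few tensors, then read off the quantization
# parameters. torch.randn is used only to fabricate example inputs.
if __name__ == '__main__':
    obs = default_observer()()      # the factory returns a fresh Observer
    for _ in range(3):
        obs(torch.randn(4, 4))      # forward records min/max, returns input unchanged
    qparams = obs.calculate_qparams()
    print('scale = %f, zero_point = %d' % (qparams[0].item(), int(qparams[1].item())))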