type check for torch.quantization.observer (#45630)

Summary:
Add mypy type annotations to torch.quantization.observer and drop its ignore_errors override from mypy.ini.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45630

Reviewed By: malfet

Differential Revision: D24058304

Pulled By: walterddr

fbshipit-source-id: ac1c0f5ff0d34b0445bd1364653fc5c9d7571b05
Authored by Rong Rong on 2020-10-02 12:09:16 -07:00; committed by Facebook GitHub Bot
parent db8b076272
commit 322855e380
2 changed files with 40 additions and 23 deletions

mypy.ini

@@ -68,9 +68,6 @@ ignore_errors = True
[mypy-torch.testing._internal.distributed.*]
ignore_errors = True
-[mypy-torch.quantization.observer]
-ignore_errors = True
[mypy-torch.quantization.stubs]
ignore_errors = True
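
With the override gone, the module is covered by the repository's regular mypy run. A quick local check (assuming mypy.ini sits at the repo root, as it does in pytorch/pytorch) would be along the lines of:

    mypy --config-file mypy.ini torch/quantization/observer.py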

torch/quantization/observer.py

@@ -2,7 +2,7 @@
import warnings
from abc import ABCMeta, abstractmethod
from functools import partial
-from typing import List, Tuple, Optional, Dict, Union
+from typing import Any, List, Tuple, Optional, Dict, Union
from collections import OrderedDict
import torch
import torch.nn as nn
@@ -38,7 +38,7 @@ def _with_args(cls_or_self, **kwargs):
return r
-ABC = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3:
+ABC: Any = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3:
class ObserverBase(ABC, nn.Module):
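
The Any annotation is the standard workaround for mypy's refusal to accept a dynamically computed base class; a self-contained sketch of the pattern:

    from abc import ABCMeta
    from typing import Any

    # Without ': Any', mypy flags 'class Base(ABC)' below as having an
    # invalid base class, since ABC here is a value built at runtime.
    ABC: Any = ABCMeta("ABC", (object,), {})

    class Base(ABC):
        pass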
@@ -111,6 +111,8 @@ class _ObserverBase(ObserverBase):
# min_val and max_val buffers from torch.Size([0]) to torch.Size([])
_version = 2
+eps: torch.Tensor
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
reduce_range=False, quant_min=None, quant_max=None):
super(_ObserverBase, self).__init__(dtype=dtype)
@@ -155,8 +157,7 @@ class _ObserverBase(ObserverBase):
missing_keys, unexpected_keys, error_msgs)
@torch.jit.export
-def _validate_qmin_qmax(self, quant_min, quant_max):
-# type: (int, int) -> None
+def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
r"""Validates that the user-specified quantization range is properly initialized
and within the given bound supported by the observer dtype.
@@ -176,8 +177,7 @@ class _ObserverBase(ObserverBase):
assert quant_min < quant_max, "qmin must be strictly less than qmax for user-specified quantization range."
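
For reference, a minimal usage sketch of what the validation enforces (illustrative values, assuming the eager MinMaxObserver constructor shown later in this diff):

    import torch
    from torch.quantization.observer import MinMaxObserver

    # A user-specified range must satisfy quant_min < quant_max and stay
    # within the bounds of the observer dtype (e.g. [0, 255] for quint8).
    obs = MinMaxObserver(dtype=torch.quint8, quant_min=0, quant_max=127)  # accepted
    # MinMaxObserver(dtype=torch.quint8, quant_min=127, quant_max=0)  # AssertionError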
@torch.jit.export
-def _calculate_qmin_qmax(self):
-# type: () -> Tuple[int, int]
+def _calculate_qmin_qmax(self) -> Tuple[int, int]:
r"""Calculates actual qmin and qmax based on the quantization range,
observer datatype and if range is reduced.
"""
@@ -216,8 +216,7 @@ class _ObserverBase(ObserverBase):
return quant_min, quant_max
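
For reference, the default ranges this computes when no custom range is set (values as implemented by the observers):

    # dtype                      -> (qmin, qmax)
    # torch.quint8               -> (0, 255)
    # torch.quint8, reduce_range -> (0, 127)
    # torch.qint8                -> (-128, 127)
    # torch.qint8,  reduce_range -> (-64, 63)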
@torch.jit.export
-def _calculate_qparams(self, min_val, max_val):
-# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
+def _calculate_qparams(self, min_val: torch.Tensor, max_val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Calculates the quantization parameters, given min and max
value tensors. Works for both per tensor and per channel cases
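
The per-tensor affine case reduces to a small amount of arithmetic; a minimal standalone sketch (affine_qparams is an illustrative helper, not part of the module):

    import torch

    def affine_qparams(min_val: float, max_val: float,
                       qmin: int = 0, qmax: int = 255):
        # Extend the range to include 0 so that zero is exactly representable.
        min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
        eps = torch.finfo(torch.float32).eps
        scale = max((max_val - min_val) / (qmax - qmin), eps)
        zero_point = int(round(qmin - min_val / scale))
        return scale, min(qmax, max(qmin, zero_point))  # clamp into [qmin, qmax]

    # e.g. affine_qparams(-1.0, 3.0) -> (~0.0157, 64)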
@@ -362,6 +361,8 @@ class MinMaxObserver(_ObserverBase):
.. note:: If the running minimum equals the running maximum, the scale
and zero_point are set to 1.0 and 0.
"""
+min_val: torch.Tensor
+max_val: torch.Tensor
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
reduce_range=False, quant_min=None, quant_max=None):
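
The min_val / max_val annotations just added follow a general pattern: buffers are registered via setattr in __init__, so mypy (and TorchScript) need a class-level declaration to know the attribute's type. A minimal sketch with hypothetical names:

    import torch
    import torch.nn as nn

    class RunningMax(nn.Module):
        # Declared here; assigned via register_buffer below. Without this
        # annotation mypy cannot infer a type for self.watermark.
        watermark: torch.Tensor

        def __init__(self) -> None:
            super().__init__()
            self.register_buffer('watermark', torch.tensor(float('-inf')))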
@@ -501,6 +502,9 @@ class PerChannelMinMaxObserver(_ObserverBase):
.. note:: If the running minimum equals the running maximum, the scales
and zero_points are set to 1.0 and 0.
"""
+min_vals: torch.Tensor
+max_vals: torch.Tensor
def __init__(self, ch_axis=0, dtype=torch.quint8,
qscheme=torch.per_channel_affine, reduce_range=False,
@@ -679,6 +683,9 @@ class HistogramObserver(_ObserverBase):
3. Compute the scale and zero point the same way as in the
:class:`~torch.quantization.MinMaxObserver`
"""
+histogram: torch.Tensor
+min_val: torch.Tensor
+max_val: torch.Tensor
def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
qscheme=torch.per_tensor_affine, reduce_range=False):
@@ -821,8 +828,10 @@ class HistogramObserver(_ObserverBase):
return new_min, new_max
@torch.jit.ignore
-def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
-# type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor, int, int]
+def _adjust_min_max(self,
+combined_min: torch.Tensor,
+combined_max: torch.Tensor,
+upsample_rate: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
# We ensure that:
# (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
# This allows us to have a common grid of resolution s, where we can align
@@ -830,17 +839,22 @@ class HistogramObserver(_ObserverBase):
# start_idx maps min_val to the histogram bin index.
hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate)
-downsample_rate = torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).to(torch.int).item()
+downsample_rate = int(torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).item())
e = downsample_rate * (self.bins * hist_bin_width) - (combined_max - combined_min)
# Relax only the max, not the min, so that for one sided distributions, min stays at zero
combined_max = combined_max + e
combined_min = combined_min
-start_idx = torch.round((self.min_val - combined_min) / hist_bin_width).to(torch.int).item()
+start_idx = int(torch.round((self.min_val - combined_min) / hist_bin_width).item())
return combined_min, combined_max, downsample_rate, start_idx
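
The int(...) wrappers exist because Tensor.item() is typed as returning a plain Python number, so the explicit cast gives mypy a definite int; the alignment math itself is unchanged. A worked example with made-up numbers:

    # Suppose bins=4, upsample_rate=2, current range [0, 4], combined [0, 5]:
    #   hist_bin_width  = (4 - 0) / (4 * 2)         = 0.5
    #   downsample_rate = ceil((5 - 0) / (4 * 0.5)) = 3
    #   e               = 3 * (4 * 0.5) - (5 - 0)   = 1.0  # slack, added to max only
    #   combined_max    = 5 + 1.0                   = 6.0  # min stays at 0
    #   start_idx       = round((0 - 0) / 0.5)      = 0
    # Check: (6 - 0) / (3 * 4) == (4 - 0) / (2 * 4) == 0.5, i.e. the fine grid
    # tiles the combined range at an integer downsample rate, as the comment above requires.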
@torch.jit.ignore
-def _combine_histograms(self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins):
-# type: (Tensor, Tensor, int, int, int, int) -> Tensor
+def _combine_histograms(self,
+orig_hist: torch.Tensor,
+new_hist: torch.Tensor,
+upsample_rate: int,
+downsample_rate: int,
+start_idx: int,
+Nbins: int) -> torch.Tensor:
# First up-sample the histogram with new data by a factor of L
# This creates an approximate probability density that's piecewise constant
upsampled_histogram = new_hist.repeat_interleave(upsample_rate)
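
repeat_interleave replicates each bin count upsample_rate times, which is what makes the density piecewise constant on the finer grid; a quick illustration:

    import torch

    hist = torch.tensor([1., 3.])
    hist.repeat_interleave(3)  # tensor([1., 1., 1., 3., 3., 3.])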
@@ -862,7 +876,7 @@ class HistogramObserver(_ObserverBase):
return orig_hist
def forward(self, x_orig):
-# type: (Tensor) -> Tensor
+# type: (torch.Tensor) -> torch.Tensor
x = x_orig.detach()
min_val = self.min_val
max_val = self.max_val
@@ -874,7 +888,10 @@ class HistogramObserver(_ObserverBase):
self.min_val.copy_(min_val)
self.max_val.resize_(max_val.shape)
self.max_val.copy_(max_val)
-torch.histc(x, self.bins, min=min_val, max=max_val, out=self.histogram)
+assert min_val.numel() == 1 and max_val.numel() == 1, (
+"histogram min/max values must be scalar."
+)
+torch.histc(x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram)
else:
new_min, new_max = torch._aminmax(x)
combined_min = torch.min(new_min, min_val)
@@ -884,7 +901,10 @@ class HistogramObserver(_ObserverBase):
# and then downsampling the histogram efficiently
combined_min, combined_max, downsample_rate, start_idx = \
self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
-combined_histogram = torch.histc(x, self.bins, min=combined_min, max=combined_max)
+assert combined_min.numel() == 1 and combined_max.numel() == 1, (
+"histogram min/max values must be scalar."
+)
+combined_histogram = torch.histc(x, self.bins, min=int(combined_min), max=int(combined_max))
if combined_min == min_val and combined_max == max_val:
combined_histogram += self.histogram
else:
@@ -1075,7 +1095,7 @@ def get_observer_state_dict(mod):
for k, v in mod.state_dict().items():
if 'activation_post_process' in k:
od[k] = v
-od._metadata = mod.state_dict()._metadata
+od._metadata = mod.state_dict()._metadata # type: ignore[attr-defined]
return od
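
The targeted suppression is needed because _metadata is a private attribute PyTorch attaches to the state-dict OrderedDict at runtime; the typeshed stubs for OrderedDict do not declare it. A minimal repro of the same error class:

    from collections import OrderedDict

    od: OrderedDict = OrderedDict()
    od._metadata = {}  # type: ignore[attr-defined]  # attribute absent from the stubs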
def load_observer_state_dict(mod, obs_dict):
@@ -1084,8 +1104,8 @@ def load_observer_state_dict(mod, obs_dict):
load the stats back into the model. The observer state_dict can be saved
using torch.quantization.get_observer_state_dict
"""
-missing_keys = []
-unexpected_keys = []
+missing_keys: List[str] = []
+unexpected_keys: List[str] = []
for name, module in mod.named_modules():
prefix = name + '.'
if _is_activation_post_process(module):
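
The List[str] annotations are needed because mypy cannot infer an element type for an empty list literal; without them it reports a "need type annotation" error for these variables. A minimal sketch (hypothetical names):

    from typing import List

    # missing = []              # mypy: need type annotation for "missing"
    missing: List[str] = []     # element type declared up front
    missing.append('activation_post_process.min_val')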