type check for torch.quantization.observer (#45630)
Summary: add type checker for observer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45630

Reviewed By: malfet
Differential Revision: D24058304
Pulled By: walterddr
fbshipit-source-id: ac1c0f5ff0d34b0445bd1364653fc5c9d7571b05
This commit is contained in: parent db8b076272, commit 322855e380

Files changed: mypy.ini (3 deletions), torch/quantization/observer.py
mypy.ini
@@ -68,9 +68,6 @@ ignore_errors = True
 [mypy-torch.testing._internal.distributed.*]
 ignore_errors = True
 
-[mypy-torch.quantization.observer]
-ignore_errors = True
-
 [mypy-torch.quantization.stubs]
 ignore_errors = True
 
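With the ignore_errors override removed, mypy now type-checks torch.quantization.observer on every run. A minimal sketch of verifying this locally, using mypy's programmatic API (the path assumes the pytorch repo root; this script is illustrative, not part of the commit):

# Sketch: confirm the module type-checks after this change.
from mypy import api

stdout, stderr, exit_code = api.run(["torch/quantization/observer.py"])
print(stdout)
assert exit_code == 0, "observer.py should now pass mypy cleanly"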
torch/quantization/observer.py
@@ -2,7 +2,7 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 from functools import partial
-from typing import List, Tuple, Optional, Dict, Union
+from typing import Any, List, Tuple, Optional, Dict, Union
 from collections import OrderedDict
 import torch
 import torch.nn as nn
@@ -38,7 +38,7 @@ def _with_args(cls_or_self, **kwargs):
     return r
 
 
-ABC = ABCMeta(str("ABC"), (object,), {})  # compatible with Python 2 *and* 3:
+ABC: Any = ABCMeta(str("ABC"), (object,), {})  # compatible with Python 2 *and* 3:
 
 
 class ObserverBase(ABC, nn.Module):
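The new Any import and the ABC: Any annotation go together: mypy cannot tell that a class created dynamically through ABCMeta(...) is a valid base class, so subclassing it is an error unless the variable is typed as Any. A small sketch of the failure mode (the names here are illustrative, not from the diff):

from abc import ABCMeta
from typing import Any

DynamicABC = ABCMeta("DynamicABC", (object,), {})
# mypy: "Variable 'DynamicABC' is not valid as a type" when used as a base.

TypedABC: Any = ABCMeta("TypedABC", (object,), {})

class Works(TypedABC):  # accepted: a base class typed as Any is allowed
    pass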
@@ -111,6 +111,8 @@ class _ObserverBase(ObserverBase):
     # min_val and max_val buffers from torch.Size([0]) to torch.Size([])
     _version = 2
 
+    eps: torch.Tensor
+
     def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                  reduce_range=False, quant_min=None, quant_max=None):
         super(_ObserverBase, self).__init__(dtype=dtype)
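The class-level eps: torch.Tensor annotation is needed because the buffer is only created at runtime via register_buffer, which mypy cannot see through nn.Module's dynamic __getattr__. A minimal sketch of the pattern (TinyObserver is a made-up module, not from this diff):

import torch
import torch.nn as nn

class TinyObserver(nn.Module):
    eps: torch.Tensor  # declares the type of a buffer registered in __init__

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Without the annotation, mypy sees self.eps as the Union returned
        # by nn.Module.__getattr__ and rejects tensor arithmetic on it.
        return x.clamp(min=float(self.eps))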
@@ -155,8 +157,7 @@ class _ObserverBase(ObserverBase):
                                       missing_keys, unexpected_keys, error_msgs)
 
     @torch.jit.export
-    def _validate_qmin_qmax(self, quant_min, quant_max):
-        # type: (int, int) -> None
+    def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
         r"""Validates that the user-specified quantization range is properly initialized
         and within the given bound supported by the observer dtype.
 
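This hunk shows the mechanical pattern the rest of the diff repeats: Python 2 style "# type:" comments are replaced by inline annotations, which both mypy and TorchScript read natively. Schematically (a hypothetical function, not from the diff):

# Before: signature types live in a comment.
def validate_range_old(quant_min, quant_max):
    # type: (int, int) -> None
    assert quant_min < quant_max

# After: the same contract as inline annotations.
def validate_range_new(quant_min: int, quant_max: int) -> None:
    assert quant_min < quant_max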
@@ -176,8 +177,7 @@ class _ObserverBase(ObserverBase):
         assert quant_min < quant_max, "qmin must be strictly less than qmax for user-specified quantization range."
 
     @torch.jit.export
-    def _calculate_qmin_qmax(self):
-        # type: () -> Tuple[int, int]
+    def _calculate_qmin_qmax(self) -> Tuple[int, int]:
         r"""Calculates actual qmin and qmax based on the quantization range,
         observer datatype and if range is reduced.
         """
@@ -216,8 +216,7 @@ class _ObserverBase(ObserverBase):
     return quant_min, quant_max
 
     @torch.jit.export
-    def _calculate_qparams(self, min_val, max_val):
-        # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
+    def _calculate_qparams(self, min_val: torch.Tensor, max_val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""Calculates the quantization parameters, given min and max
         value tensors. Works for both per tensor and per channel cases
 
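The math behind _calculate_qparams, sketched for the per-tensor affine case only (a simplification of the real method, which also handles symmetric and per-channel schemes):

import torch
from typing import Tuple

def affine_qparams(min_val: torch.Tensor, max_val: torch.Tensor,
                   qmin: int = 0, qmax: int = 255) -> Tuple[torch.Tensor, torch.Tensor]:
    # The quantized range must contain zero, so widen the bounds around it.
    min_val = torch.min(min_val, torch.zeros_like(min_val))
    max_val = torch.max(max_val, torch.zeros_like(max_val))
    scale = (max_val - min_val) / float(qmax - qmin)
    zero_point = (qmin - torch.round(min_val / scale)).clamp(qmin, qmax).to(torch.int)
    return scale, zero_point

scale, zp = affine_qparams(torch.tensor(-1.0), torch.tensor(3.0))  # scale ~0.0157, zp = 64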
@@ -362,6 +361,8 @@ class MinMaxObserver(_ObserverBase):
     .. note:: If the running minimum equals to the running maximum, the scale
               and zero_point are set to 1.0 and 0.
     """
+    min_val: torch.Tensor
+    max_val: torch.Tensor
 
     def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                  reduce_range=False, quant_min=None, quant_max=None):
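A usage sketch for MinMaxObserver as of this version of PyTorch; the newly annotated min_val/max_val are the buffers that forward() updates:

import torch
from torch.quantization.observer import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
obs(torch.randn(16, 8))                      # forward() records the running min/max
scale, zero_point = obs.calculate_qparams()  # derived from min_val/max_val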
@@ -501,6 +502,9 @@ class PerChannelMinMaxObserver(_ObserverBase):
     .. note:: If the running minimum equals to the running maximum, the scales
               and zero_points are set to 1.0 and 0.
     """
+    min_vals: torch.Tensor
+    max_vals: torch.Tensor
+
 
     def __init__(self, ch_axis=0, dtype=torch.quint8,
                  qscheme=torch.per_channel_affine, reduce_range=False,
@@ -679,6 +683,9 @@ class HistogramObserver(_ObserverBase):
     3. Compute the scale and zero point the same way as in the
        :class:`~torch.quantization.MinMaxObserver`
     """
+    histogram: torch.Tensor
+    min_val: torch.Tensor
+    max_val: torch.Tensor
 
     def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
                  qscheme=torch.per_tensor_affine, reduce_range=False):
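A matching usage sketch for HistogramObserver; the three newly annotated buffers back the histogram-search algorithm described in the docstring above:

import torch
from torch.quantization.observer import HistogramObserver

obs = HistogramObserver(bins=2048, upsample_rate=128, dtype=torch.quint8)
obs(torch.randn(16, 8))                      # accumulates the histogram buffer
scale, zero_point = obs.calculate_qparams()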
@@ -821,8 +828,10 @@ class HistogramObserver(_ObserverBase):
         return new_min, new_max
 
     @torch.jit.ignore
-    def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
-        # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor, int, int]
+    def _adjust_min_max(self,
+                        combined_min: torch.Tensor,
+                        combined_max: torch.Tensor,
+                        upsample_rate: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
         # We ensure that:
         # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
         # This allows us to have a common grid of resolution s, where we can align
@@ -830,17 +839,22 @@ class HistogramObserver(_ObserverBase):
         # start_idx maps min_val to the histogram bin index.
 
         hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate)
-        downsample_rate = torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).to(torch.int).item()
+        downsample_rate = int(torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).item())
         e = downsample_rate * (self.bins * hist_bin_width) - (combined_max - combined_min)
         # Relax only the max, not the min, so that for one sided distributions, min stays at zero
         combined_max = combined_max + e
         combined_min = combined_min
-        start_idx = torch.round((self.min_val - combined_min) / hist_bin_width).to(torch.int).item()
+        start_idx = int(torch.round((self.min_val - combined_min) / hist_bin_width).item())
         return combined_min, combined_max, downsample_rate, start_idx
 
     @torch.jit.ignore
-    def _combine_histograms(self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins):
-        # type: (Tensor, Tensor, int, int, int, int) -> Tensor
+    def _combine_histograms(self,
+                            orig_hist: torch.Tensor,
+                            new_hist: torch.Tensor,
+                            upsample_rate: int,
+                            downsample_rate: int,
+                            start_idx: int,
+                            Nbins: int) -> torch.Tensor:
         # First up-sample the histogram with new data by a factor of L
         # This creates an approximate probability density thats piecwise constant
         upsampled_histogram = new_hist.repeat_interleave(upsample_rate)
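The invariant in the comment, (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins), pins both histograms to a common fine grid. A worked numeric sketch with made-up values:

import math

bins, upsample_rate = 4, 2
min_val, max_val = 0.0, 8.0              # range of the stored histogram
combined_min, combined_max = 0.0, 19.0   # range after merging new data

hist_bin_width = (max_val - min_val) / (bins * upsample_rate)                         # 1.0
downsample_rate = math.ceil((combined_max - combined_min) / (bins * hist_bin_width))  # 5
# Relax only the max so the combined range is an exact multiple of the fine grid:
combined_max = combined_min + downsample_rate * bins * hist_bin_width                 # 20.0
start_idx = round((min_val - combined_min) / hist_bin_width)                          # 0
# Both sides now step in units of hist_bin_width: (20-0)/(5*4) == (8-0)/(2*4) == 1.0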
@@ -862,7 +876,7 @@ class HistogramObserver(_ObserverBase):
         return orig_hist
 
     def forward(self, x_orig):
-        # type: (Tensor) -> Tensor
+        # type: (torch.Tensor) -> torch.Tensor
         x = x_orig.detach()
         min_val = self.min_val
         max_val = self.max_val
@@ -874,7 +888,10 @@ class HistogramObserver(_ObserverBase):
             self.min_val.copy_(min_val)
             self.max_val.resize_(max_val.shape)
             self.max_val.copy_(max_val)
-            torch.histc(x, self.bins, min=min_val, max=max_val, out=self.histogram)
+            assert min_val.numel() == 1 and max_val.numel() == 1, (
+                "histogram min/max values must be scalar."
+            )
+            torch.histc(x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram)
         else:
             new_min, new_max = torch._aminmax(x)
             combined_min = torch.min(new_min, min_val)
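The int(...) unwrapping and the new assertions exist because torch.histc is typed to take plain Python numbers for min and max, while min_val/max_val here are 0-dim tensors, so passing them directly fails the new type check. A standalone sketch of the call (float() is used here as the lossless conversion; the diff itself uses int()):

import torch

x = torch.randn(1000)
new_min, new_max = torch._aminmax(x)     # both 0-dim tensors, as in the diff
hist = torch.histc(x, bins=16, min=float(new_min), max=float(new_max))
assert int(hist.sum().item()) == x.numel()   # every sample lands in some bin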
@@ -884,7 +901,10 @@ class HistogramObserver(_ObserverBase):
             # and then downsampling the histogram efficiently
             combined_min, combined_max, downsample_rate, start_idx = \
                 self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
-            combined_histogram = torch.histc(x, self.bins, min=combined_min, max=combined_max)
+            assert combined_min.numel() == 1 and combined_max.numel() == 1, (
+                "histogram min/max values must be scalar."
+            )
+            combined_histogram = torch.histc(x, self.bins, min=int(combined_min), max=int(combined_max))
             if combined_min == min_val and combined_max == max_val:
                 combined_histogram += self.histogram
             else:
@@ -1075,7 +1095,7 @@ def get_observer_state_dict(mod):
     for k, v in mod.state_dict().items():
         if 'activation_post_process' in k:
             od[k] = v
-    od._metadata = mod.state_dict()._metadata
+    od._metadata = mod.state_dict()._metadata  # type: ignore[attr-defined]
     return od
 
 def load_observer_state_dict(mod, obs_dict):
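The new "type: ignore[attr-defined]" is needed because _metadata is attached dynamically to the OrderedDict returned by nn.Module.state_dict(), so it is invisible to mypy's stubs. The mechanism in isolation:

from collections import OrderedDict

od: "OrderedDict[str, int]" = OrderedDict()
# OrderedDict instances accept ad-hoc attributes at runtime, but the stub
# for OrderedDict declares no _metadata, so mypy must be silenced here.
od._metadata = {}  # type: ignore[attr-defined]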
@@ -1084,8 +1104,8 @@ def load_observer_state_dict(mod, obs_dict):
     load the stats back into the model. The observer state_dict can be saved
     using torch.quantization.get_observer_state_dict
     """
-    missing_keys = []
-    unexpected_keys = []
+    missing_keys: List[str] = []
+    unexpected_keys: List[str] = []
     for name, module in mod.named_modules():
         prefix = name + '.'
         if _is_activation_post_process(module):
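The final pair of changes is a classic mypy requirement: an empty list literal has no inferable element type, so lists that are only populated later (here, by the state-dict load hooks) need explicit annotations. In isolation:

from typing import List

missing_keys: List[str] = []      # without the annotation: "need type annotation"
unexpected_keys: List[str] = []
missing_keys.append("observer.min_val")   # element type now checked as str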