mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65443 Test Plan: Imported from OSS Reviewed By: dagitses, supriyar Differential Revision: D31456445 Pulled By: b-koopman fbshipit-source-id: 0edda6e272d9005fce65f2ba6a5e6abc831836de
405 lines
18 KiB
Python
405 lines
18 KiB
Python
import torch
|
|
from torch.nn import Module
|
|
from torch.ao.quantization.observer import (
|
|
MovingAverageMinMaxObserver,
|
|
HistogramObserver,
|
|
MovingAveragePerChannelMinMaxObserver,
|
|
PerChannelMinMaxObserver,
|
|
_with_args,
|
|
)
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Tuple
|
|
|
|
def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
|
|
return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]
|
|
|
|
def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
|
|
return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
|
|
|
|
def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
|
|
return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
|
|
|
|
class FakeQuantizeBase(ABC, Module):
    r"""Abstract parent of every fake quantize module.

    Subclasses share one API: ``forward`` is expected to update the statistics
    of the observed tensor and fake quantize the input, and
    ``calculate_qparams`` is expected to compute the quantization parameters
    from the collected statistics.

    The two on/off switches are stored as one-element buffers and toggled
    through the ``enable_*`` / ``disable_*`` methods below.
    """

    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    def __init__(self):
        super().__init__()
        # Both switches are buffers so that they get replicated in DDP. They
        # are uint8 rather than bool because NCCL does not support bool tensors.
        for flag_name in ('fake_quant_enabled', 'observer_enabled'):
            self.register_buffer(flag_name, torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        """Process ``x``; must be provided by the concrete subclass."""

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        """Return the quantization parameters; must be provided by the concrete subclass."""

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        """Write 1 (``enabled`` true) or 0 into the ``fake_quant_enabled`` buffer."""
        if enabled:
            self.fake_quant_enabled[0] = 1
        else:
            self.fake_quant_enabled[0] = 0

    @torch.jit.export
    def disable_fake_quant(self):
        """Shorthand for ``enable_fake_quant(False)``."""
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        """Write 1 (``enabled`` true) or 0 into the ``observer_enabled`` buffer."""
        if enabled:
            self.observer_enabled[0] = 1
        else:
            self.observer_enabled[0] = 0

    @torch.jit.export
    def disable_observer(self):
        """Shorthand for ``enable_observer(False)``."""
        self.enable_observer(False)

    # Class-level factory helper shared with the observer classes.
    with_args = classmethod(_with_args)
|
|
|
|
class FakeQuantize(FakeQuantizeBase):
    r""" Simulate the quantize and dequantize operations in training time.

    The output of this module is given by::

        x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point) * scale

    * :attr:`scale` defines the scale factor used for quantization.

    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to

    * :attr:`quant_min` specifies the minimum allowable quantized value.

    * :attr:`quant_max` specifies the maximum allowable quantized value.

    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that
      statistics can still be updated.

    * :attr:`observer_enabled` controls statistics collection on tensors

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
      allowable values are torch.qint8 and torch.quint8. The values of quant_min and
      quant_max should be chosen to be consistent with the dtype

    Args:
        observer (module): Module for observing statistics on input tensors and calculating scale
                           and zero-point.
        quant_min (int): The minimum allowable quantized value.
        quant_max (int): The maximum allowable quantized value.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        activation_post_process (Module): Instance of the user provided observer that collects
                                          statistics on the input tensor and provides a method to
                                          calculate scale and zero-point.

    Raises:
        AssertionError: if ``quant_min > quant_max``, if either bound falls outside the integer
                        range of the observer's dtype, or if the observer's qscheme is neither
                        per tensor nor per channel.
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
        super().__init__()
        assert quant_min <= quant_max, \
            'quant_min must be less than or equal to quant_max'
        self.quant_min = quant_min
        self.quant_max = quant_max
        # The observer is built only from observer_kwargs; quant_min/quant_max are
        # validated against its dtype below.
        self.activation_post_process = observer(**observer_kwargs)
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
        # One-element placeholders; forward() and _load_from_state_dict() resize them when the
        # incoming qparams have a different shape.
        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        # Observers without a channel axis get the sentinel -1.
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        assert _is_per_channel(self.qscheme) or \
            _is_per_tensor(self.qscheme), \
            'Only per channel and per tensor quantization are supported in fake quantize' + \
            ' got qscheme: ' + str(self.qscheme)
        self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        """Return the qparams computed by the wrapped observer."""
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        """Optionally observe ``X`` and refresh the qparams, then optionally fake quantize it.

        Both steps are gated by the ``observer_enabled`` / ``fake_quant_enabled`` buffers.
        """
        if self.observer_enabled[0] == 1:
            # Statistics are collected on a detached view of the input.
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point,
                    self.ch_axis, self.quant_min, self.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.quant_min, self.quant_max)
        return X

    @torch.jit.export
    def extra_repr(self):
        """Return the flag, range, dtype, qscheme, axis and qparam values as one string."""
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.quant_min, self.quant_max,
                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save the module state, then store ``scale`` and ``zero_point`` explicitly."""
        # We cannot currently register scalar values as buffers, so need to manually
        # specify serialization here.
        super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Resize ``scale``/``zero_point`` to the checkpoint shapes, then defer to the parent loader."""
        # Without this override, loading fails with a size mismatch whenever the checkpoint holds
        # scale/zero_point of a different shape than the current buffers (they are created with
        # a single element in __init__).
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Custom handling to allow loading scale and zero_point of size N into buffers
                # of a different size. The buffers are resized here, and the values are copied
                # in the default state_dict loading code of the parent.
                if name == 'scale':
                    self.scale.resize_(val.shape)
                else:
                    assert name == 'zero_point'
                    self.zero_point.resize_(val.shape)
                # For torchscript module we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined module.py
                if torch.jit.is_scripting():
                    if name == 'scale':
                        self.scale.copy_(val)
                    else:
                        assert name == 'zero_point'
                        self.zero_point.copy_(val)
            # A missing key is deliberately NOT appended to missing_keys here: scale and
            # zero_point are registered buffers, so the parent call below already records them
            # when strict is set. Appending here as well listed every key twice.
        super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                        missing_keys, unexpected_keys, error_msgs)
|
|
|
|
class FixedQParamsFakeQuantize(FakeQuantizeBase):
    """Fake quantize module whose quantization parameters never change.

    Quantize and dequantize are simulated in training time with the scale and
    zero point supplied at construction. Only per tensor quantization is
    supported.

    Args:
        `scale` (float): fixed scale for the fake quantize module
        `zero_point` (int): fixed zero point for the fake quantize module
        `dtype`, `qscheme`, `quant_min`, `quant_max`
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self,
                 scale,
                 zero_point,
                 dtype=torch.quint8,
                 qscheme=torch.per_tensor_affine,
                 quant_min=0,
                 quant_max=255):
        super().__init__()
        assert quant_min <= quant_max, 'quant_min should be less than or equal to quant_max'
        self.quant_min = quant_min
        self.quant_max = quant_max

        # The fixed qparams are stored as one-element buffers.
        fixed_scale = torch.tensor([scale], dtype=torch.float)
        fixed_zero_point = torch.tensor([zero_point], dtype=torch.int)
        self.register_buffer('scale', fixed_scale)
        self.register_buffer('zero_point', fixed_zero_point)

        self.dtype = dtype
        self.qscheme = qscheme
        assert _is_per_tensor(self.qscheme), (
            'Only per tensor quantization is supported'
            ' FixedQParamsFakeQuantize module, got qscheme:' + str(self.qscheme)
        )

    def forward(self, X):
        """Fake quantize ``X`` with the fixed qparams unless fake quant is switched off."""
        if self.fake_quant_enabled[0] == 1:
            X = torch.fake_quantize_per_tensor_affine(
                X, self.scale, self.zero_point, self.quant_min, self.quant_max
            )
        return X

    @torch.jit.export
    def calculate_qparams(self):
        """Return the stored ``(scale, zero_point)`` buffers."""
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self):
        """Return the flags, qparams, dtype, range and qscheme as one string."""
        template = 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
                   'dtype={}, quant_min={}, quant_max={}, qscheme={}'
        return template.format(
            self.fake_quant_enabled, self.observer_enabled,
            self.scale, self.zero_point, self.dtype,
            self.quant_min, self.quant_max, self.qscheme)
|
|
|
|
class FusedMovingAvgObsFakeQuantize(FakeQuantize):
    r"""Observer and fake quantize fused into a single module.

    One call observes the input tensor (min/max), computes scale/zero_point
    and fake quantizes the tensor. The min/max computation is similar to
    MovingAverageMinMaxObserver, and the observer's qscheme decides between
    symmetric and affine quantization.

    The output of this module is given by
    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale

    Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the
    base class.
    """

    def __init__(
        self,
        observer: Any = MovingAverageMinMaxObserver,
        quant_min: int = 0,
        quant_max: int = 255,
        **observer_kwargs: Any
    ) -> None:
        super().__init__(observer, quant_min, quant_max, **observer_kwargs)
        post_process = self.activation_post_process
        supported_observers = (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)
        assert isinstance(post_process, supported_observers), \
            "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
        self.quant_min: int = quant_min
        self.quant_max: int = quant_max
        # The on/off switches are re-registered here as int64 tensors, replacing the
        # uint8 buffers created by the base class.
        for flag_name in ("fake_quant_enabled", "observer_enabled"):
            self.register_buffer(flag_name, torch.tensor([1], dtype=torch.long))
        self.is_symmetric_quant = _is_symmetric_quant(post_process.qscheme)

        # The final quantization range is whatever the observer reports.
        observed_min, observed_max = post_process.quant_min, post_process.quant_max
        self.quant_min = observed_min
        self.quant_max = observed_max

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the qparams computed by the wrapped observer."""
        return self.activation_post_process.calculate_qparams()

    @torch.jit.export
    def extra_repr(self) -> str:
        """Return the flags, qparams, dtype, range, qscheme and reduce_range as one string."""
        template = (
            "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
            "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}"
        )
        return template.format(
            self.fake_quant_enabled,
            self.observer_enabled,
            self.scale,
            self.zero_point,
            self.dtype,
            self.quant_min,
            self.quant_max,
            self.qscheme,
            self.activation_post_process.reduce_range,
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run the fused observe + fake quantize op on ``X`` and return its result."""
        obs = self.activation_post_process
        return torch.fused_moving_avg_obs_fake_quant(
            X,
            self.observer_enabled,
            self.fake_quant_enabled,
            obs.min_val,
            obs.max_val,
            self.scale,
            self.zero_point,
            obs.averaging_constant,
            self.quant_min,
            self.quant_max,
            self.ch_axis,
            self.is_per_channel,
            self.is_symmetric_quant,
        )
|
|
|
|
# Activation fake quant: moving-average min/max, per tensor affine quint8 in [0, 255], reduce_range on.
default_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)

# Weight fake quant: moving-average min/max, per tensor symmetric qint8 in [-128, 127].
default_weight_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False,
)

# TODO(future PR): remove these defaults and enforce activation functions
# to explicitly specify their output range
default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
    scale=2.0 / 256.0,
    zero_point=128,
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
)
default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
    scale=1.0 / 256.0,
    zero_point=0,
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
)

# Per channel weight fake quant: symmetric qint8 in [-128, 127] along channel axis 0.
default_per_channel_weight_fake_quant = FakeQuantize.with_args(
    observer=MovingAveragePerChannelMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    reduce_range=False,
    ch_axis=0,
)

# Embedding fake quant: per channel float-qparams scheme on axis 0 with a memoryless observer.
default_embedding_fake_quant = FakeQuantize.with_args(
    observer=PerChannelMinMaxObserver,
    qscheme=torch.per_channel_affine_float_qparams,
    ch_axis=0,
    memoryless=True,
)

# Same settings as default_fake_quant, but statistics come from a HistogramObserver.
default_histogram_fake_quant = FakeQuantize.with_args(
    observer=HistogramObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)

# Fused observer + fake quant for activations: quint8 in [0, 255].
default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
)

# Fused observer + fake quant for weights: per tensor symmetric qint8 in [-128, 127].
default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
)

# Fused observer + fake quant for weights: per channel symmetric qint8 in [-128, 127].
default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAveragePerChannelMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
)
|
|
|
|
def _is_fake_quant_script_module(mod):
|
|
''' Returns true if given mod is an instance of FakeQuantize script module.
|
|
'''
|
|
if isinstance(mod, torch.jit.RecursiveScriptModule):
|
|
# qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
|
|
suffix = mod._c.qualified_name.split('.', 1)[1]
|
|
name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
|
|
return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \
|
|
name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
|
|
return False
|
|
|
|
def disable_fake_quant(mod):
    """Call ``mod.disable_fake_quant()`` when ``mod`` is a fake quantize module.

    Both ``FakeQuantizeBase`` instances and scripted fake quantize modules
    qualify; any other object is left untouched.
    """
    is_fake_quant = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if not is_fake_quant:
        return
    mod.disable_fake_quant()
|
|
|
|
def enable_fake_quant(mod):
    """Call ``mod.enable_fake_quant()`` when ``mod`` is a fake quantize module.

    Both ``FakeQuantizeBase`` instances and scripted fake quantize modules
    qualify; any other object is left untouched.
    """
    is_fake_quant = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if not is_fake_quant:
        return
    mod.enable_fake_quant()
|
|
|
|
def disable_observer(mod):
    """Call ``mod.disable_observer()`` when ``mod`` is a fake quantize module.

    Both ``FakeQuantizeBase`` instances and scripted fake quantize modules
    qualify; any other object is left untouched.
    """
    is_fake_quant = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if not is_fake_quant:
        return
    mod.disable_observer()
|
|
|
|
def enable_observer(mod):
    """Call ``mod.enable_observer()`` when ``mod`` is a fake quantize module.

    Both ``FakeQuantizeBase`` instances and scripted fake quantize modules
    qualify; any other object is left untouched.
    """
    is_fake_quant = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if not is_fake_quant:
        return
    mod.enable_observer()
|