pytorch/torch/quantization/fake_quantize.py
Supriya Rao 7a15576a65 [quant] update FakeQuant modules to use tensor qparams (#61318)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61318

Remove the `float()` and `int()` calls in the forward function so that we can directly use the tensor qparams in the fake_quantize operator.

Calling `float()`/`int()` internally calls `item()`, which can trigger a GPU-to-CPU copy if the original tensors reside on the GPU.
Local benchmark P427668213

Before this change
```
                                               Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
---------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                     aten::_aminmax         2.57%       1.507ms         3.10%       1.819ms      36.371us       2.872ms         4.81%       2.872ms      57.446us            50
              aten::fake_quantize_per_tensor_affine         1.04%     610.915us         3.60%       2.114ms      42.276us     472.896us         0.79%       2.698ms      53.962us            50
    aten::fake_quantize_per_tensor_affine_cachemask         1.69%     993.626us         2.56%       1.503ms      30.058us       2.225ms         3.73%       2.225ms      44.504us            50
                                   aten::is_nonzero         3.85%       2.258ms        19.68%      11.540ms      46.161us       2.168ms         3.63%      11.084ms      44.336us           250
                                   aten::zeros_like         1.82%       1.064ms         6.65%       3.901ms      39.007us       1.531ms         2.57%       3.905ms      39.045us           100
                                           aten::eq        13.80%       8.093ms        25.90%      15.189ms      37.972us       9.580ms        16.05%      15.566ms      38.914us           400
                                         aten::item         5.67%       3.323ms        21.50%      12.607ms      36.019us       3.233ms         5.42%      12.167ms      34.762us           350
                                        aten::zeros         0.94%     549.208us         2.93%       1.717ms      34.343us     688.928us         1.15%       1.695ms      33.894us            50
                                           aten::le         2.52%       1.478ms         4.50%       2.641ms      26.411us       1.753ms         2.94%       2.845ms      28.448us           100
                                         aten::rsub         1.04%     608.715us         2.44%       1.433ms      28.667us     532.000us         0.89%       1.418ms      28.353us            50
                                          aten::max         1.54%     905.401us         4.62%       2.711ms      27.106us     847.488us         1.42%       2.697ms      26.969us           100
                                         aten::ones         0.92%     542.159us         2.16%       1.266ms      25.324us     661.856us         1.11%       1.301ms      26.017us            50
                                          aten::min         0.82%     479.167us         2.15%       1.258ms      25.160us     407.808us         0.68%       1.276ms      25.530us            50
                          aten::_local_scalar_dense        15.83%       9.284ms        15.83%       9.284ms      26.526us       8.934ms        14.97%       8.934ms      25.524us           350
                                        aten::clamp         2.35%       1.378ms         4.21%       2.467ms      24.669us       1.546ms         2.59%       2.461ms      24.612us           100
                                        aten::zero_         2.53%       1.482ms         5.65%       3.316ms      22.108us       1.326ms         2.22%       3.380ms      22.531us           150
                                      aten::maximum         3.08%       1.805ms         3.08%       1.805ms      18.052us       1.849ms         3.10%       1.849ms      18.494us           100
                                      aten::minimum         1.33%     778.854us         1.33%     778.854us      15.577us     868.672us         1.46%     868.672us      17.373us            50
                                        aten::round         1.36%     799.910us         1.36%     799.910us      15.998us     809.568us         1.36%     809.568us      16.191us            50
                                        aten::copy_         6.61%       3.878ms         6.61%       3.878ms      15.513us       4.036ms         6.76%       4.036ms      16.143us           250
                                          aten::div         2.53%       1.483ms         2.53%       1.483ms      14.833us       1.535ms         2.57%       1.535ms      15.353us           100
                                          aten::mul         2.44%       1.431ms         2.44%       1.431ms      14.314us       1.478ms         2.48%       1.478ms      14.782us           100
                                       aten::detach         1.46%     855.670us         2.41%       1.411ms      14.110us     832.448us         1.39%       1.395ms      13.949us           100
                                          aten::add         2.22%       1.301ms         2.22%       1.301ms      13.008us       1.383ms         2.32%       1.383ms      13.828us           100
                                        aten::fill_         4.18%       2.452ms         4.18%       2.452ms      12.262us       2.693ms         4.51%       2.693ms      13.463us           200
                                          aten::sub         5.06%       2.967ms         5.06%       2.967ms      14.837us       2.675ms         4.48%       2.675ms      13.374us           200
                                           aten::to         2.10%       1.230ms         3.65%       2.140ms      10.701us       1.310ms         2.20%       2.062ms      10.310us           200
                                       aten::select         1.28%     749.144us         1.49%     874.227us       8.742us     863.232us         1.45%     863.232us       8.632us           100
                                             detach         0.95%     555.326us         0.95%     555.326us       5.553us     562.496us         0.94%     562.496us       5.625us           100
                                   aten::as_strided         0.40%     232.289us         0.40%     232.289us       1.161us       0.000us         0.00%       0.000us       0.000us           200
                                        aten::empty         2.93%       1.720ms         2.93%       1.720ms       3.439us       0.000us         0.00%       0.000us       0.000us           500
                                      aten::resize_         1.04%     611.313us         1.04%     611.313us       2.038us       0.000us         0.00%       0.000us       0.000us           300
                                   aten::empty_like         0.75%     438.585us         1.77%       1.036ms       5.180us       0.000us         0.00%       0.000us       0.000us           200
                                aten::empty_strided         1.36%     799.442us         1.36%     799.442us       3.198us       0.000us         0.00%       0.000us       0.000us           250
---------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 58.645ms
Self CUDA time total: 59.674ms
```

After this change
```

test_fake_quant_profiler (scripts.supriyar.benchmark.module_bench.ProfilerBench) ... -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                  aten::fake_quantize_per_tensor_affine         0.98%     505.210us         4.38%       2.259ms      45.187us     419.424us         0.78%       3.218ms      64.367us            50
                                         aten::_aminmax         2.78%       1.434ms         3.42%       1.766ms      35.321us       2.825ms         5.27%       2.825ms      56.505us            50
aten::fake_quantize_per_tensor_affine_cachemask_tens...         2.38%       1.229ms         3.40%       1.754ms      35.083us       2.799ms         5.22%       2.799ms      55.979us            50
                                             aten::rsub         0.94%     485.040us         5.02%       2.590ms      51.793us     458.976us         0.86%       2.587ms      51.747us            50
                                       aten::is_nonzero         3.78%       1.952ms        23.64%      12.196ms      48.786us       2.055ms         3.83%      11.986ms      47.944us           250
                                             aten::item         6.92%       3.572ms        19.86%      10.244ms      40.977us       3.670ms         6.85%       9.931ms      39.724us           250
                                       aten::zeros_like         1.65%     848.874us         6.64%       3.426ms      34.260us       1.397ms         2.61%       3.572ms      35.717us           100
                                            aten::zeros         0.85%     436.691us         3.00%       1.549ms      30.984us     551.936us         1.03%       1.576ms      31.516us            50
                                               aten::eq        10.60%       5.467ms        20.26%      10.452ms      26.130us       7.018ms        13.09%      10.832ms      27.079us           400
                                               aten::le         2.58%       1.332ms         4.67%       2.407ms      24.074us       1.580ms         2.95%       2.614ms      26.144us           100
                              aten::_local_scalar_dense        12.93%       6.673ms        12.93%       6.673ms      26.691us       6.261ms        11.68%       6.261ms      25.046us           250
                                            aten::clamp         2.43%       1.253ms         4.37%       2.256ms      22.560us       1.431ms         2.67%       2.273ms      22.725us           100
                                             aten::ones         0.89%     460.133us         2.18%       1.123ms      22.467us     570.496us         1.06%       1.128ms      22.551us            50
                                              aten::min         0.74%     383.132us         2.06%       1.065ms      21.296us     377.536us         0.70%       1.091ms      21.824us            50
                                            aten::zero_         2.36%       1.219ms         5.87%       3.029ms      20.194us       1.261ms         2.35%       3.199ms      21.327us           150
                                              aten::max         1.51%     779.081us         4.06%       2.096ms      20.960us     791.680us         1.48%       2.130ms      21.295us           100
                                              aten::sub         7.97%       4.111ms         7.97%       4.111ms      20.556us       3.847ms         7.18%       3.847ms      19.234us           200
                                              aten::div         2.94%       1.516ms         2.94%       1.516ms      15.158us       1.580ms         2.95%       1.580ms      15.798us           100
                                            aten::round         1.45%     750.445us         1.45%     750.445us      15.009us     756.064us         1.41%     756.064us      15.121us            50
                                            aten::copy_         6.88%       3.548ms         6.88%       3.548ms      14.190us       3.701ms         6.90%       3.701ms      14.803us           250
                                          aten::minimum         1.32%     681.654us         1.32%     681.654us      13.633us     713.664us         1.33%     713.664us      14.273us            50
                                          aten::maximum         2.55%       1.317ms         2.55%       1.317ms      13.169us       1.338ms         2.50%       1.338ms      13.378us           100
                                              aten::mul         2.63%       1.358ms         2.63%       1.358ms      13.581us       1.328ms         2.48%       1.328ms      13.283us           100
                                           aten::detach         1.34%     688.820us         2.35%       1.211ms      12.110us     772.800us         1.44%       1.278ms      12.779us           100
                                            aten::fill_         4.53%       2.338ms         4.53%       2.338ms      11.692us       2.495ms         4.65%       2.495ms      12.473us           200
                                              aten::add         2.32%       1.197ms         2.32%       1.197ms      11.968us       1.240ms         2.31%       1.240ms      12.405us           100
                                               aten::to         2.07%       1.069ms         3.66%       1.889ms       9.443us       1.224ms         2.28%       1.975ms       9.874us           200
                                           aten::select         1.44%     743.042us         1.64%     848.207us       8.482us     641.600us         1.20%     641.600us       6.416us           100
                                                 detach         1.01%     522.155us         1.01%     522.155us       5.222us     505.088us         0.94%     505.088us       5.051us           100
                                       aten::as_strided         0.44%     227.884us         0.44%     227.884us       1.139us       0.000us         0.00%       0.000us       0.000us           200
                                            aten::empty         3.20%       1.652ms         3.20%       1.652ms       3.304us       0.000us         0.00%       0.000us       0.000us           500
                                          aten::resize_         1.25%     646.711us         1.25%     646.711us       2.156us       0.000us         0.00%       0.000us       0.000us           300
                                       aten::empty_like         0.79%     407.768us         2.07%       1.067ms       5.334us       0.000us         0.00%       0.000us       0.000us           200
                                    aten::empty_strided         1.52%     785.788us         1.52%     785.788us       3.143us       0.000us         0.00%       0.000us       0.000us           250
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 51.590ms
Self CUDA time total: 53.609ms
ghstack-source-id: 133370215

Test Plan: buck test mode/dev-nosan caffe2/test/:quantization

Reviewed By: raghuramank100

Differential Revision: D29566512

fbshipit-source-id: 1aefca51f99949da7334bcfe504848275c9f952c
2021-07-10 19:43:02 -07:00

299 lines
14 KiB
Python

import torch
from torch.nn import Module
from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args
import re
from abc import ABC, abstractmethod
def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]
def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
class FakeQuantizeBase(ABC, Module):
    r""" Base fake quantize module

    Any fake quantize implementation should derive from this class and follow
    the same API: ``forward`` updates the statistics of the observed Tensor and
    fake-quantizes the input, and ``calculate_qparams`` computes the
    quantization parameters from the collected statistics.
    """

    # Single-element uint8 flag buffers: 1 means enabled, 0 means disabled.
    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    with_args = classmethod(_with_args)

    def __init__(self):
        super().__init__()
        # The flags are buffers (not Python bools) so that DDP replicates them;
        # they are uint8 because NCCL does not support bool tensors.
        # Both flags start out enabled.
        for flag in ('fake_quant_enabled', 'observer_enabled'):
            self.register_buffer(flag, torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        """Observe and fake-quantize ``x``; must be overridden by subclasses."""

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        """Compute the quantization parameters; must be overridden by subclasses."""

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        """Set the fake-quantization flag to 1 when ``enabled`` is true, else 0."""
        flag_value = 1 if enabled else 0
        self.fake_quant_enabled[0] = flag_value

    @torch.jit.export
    def disable_fake_quant(self):
        """Shorthand for ``enable_fake_quant(False)``."""
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        """Set the observer flag to 1 when ``enabled`` is true, else 0."""
        flag_value = 1 if enabled else 0
        self.observer_enabled[0] = flag_value

    @torch.jit.export
    def disable_observer(self):
        """Shorthand for ``enable_observer(False)``."""
        self.enable_observer(False)
class FakeQuantize(FakeQuantizeBase):
    r""" Simulate the quantize and dequantize operations in training time.

    The output of this module is given by::

        x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point) * scale

    * :attr:`scale` defines the scale factor used for quantization.
    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to
    * :attr:`quant_min` specifies the minimum allowable quantized value.
    * :attr:`quant_max` specifies the maximum allowable quantized value.
    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that
      statistics can still be updated.
    * :attr:`observer_enabled` controls statistics collection on tensors
    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
      allowable values are torch.qint8 and torch.quint8. The values of quant_min and
      quant_max should be chosen to be consistent with the dtype

    Args:
        observer (module): Module for observing statistics on input tensors and calculating scale
            and zero-point.
        quant_min (int): The minimum allowable quantized value.
        quant_max (int): The maximum allowable quantized value.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        activation_post_process (Module): the observer instance, built as
            ``observer(**observer_kwargs)``; it collects statistics on the input tensor and
            provides ``calculate_qparams`` to compute scale and zero-point.
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
        super().__init__()
        assert quant_min <= quant_max, \
            'quant_min must be less than or equal to quant_max'
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.activation_post_process = observer(**observer_kwargs)
        # The requested range must fit inside the integer range of the observer's dtype.
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
        # qparams start as single-element tensors; forward() resizes them whenever the
        # observer returns qparams of a different shape (e.g. one value per channel).
        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        # Observers that do not define ch_axis get -1.
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        assert _is_per_channel(self.qscheme) or \
            _is_per_tensor(self.qscheme), \
            'Only per channel and per tensor quantization are supported in fake quantize' + \
            ' got qscheme: ' + str(self.qscheme)
        self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        """Return the qparams computed by the wrapped observer."""
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        """Optionally observe ``X`` and refresh the qparam buffers, then optionally
        fake-quantize ``X`` (per-channel or per-tensor depending on the qscheme)."""
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            # Move the new qparams to the buffers' devices; they stay tensors (no
            # float()/int() conversion), which would force an item() call.
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
            # The observer may report a different shape than the current buffers
            # (e.g. one value per channel), so resize in place before copying.
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point,
                    self.ch_axis, self.quant_min, self.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.quant_min, self.quant_max)
        return X

    @torch.jit.export
    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.quant_min, self.quant_max,
                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # NOTE(review): scale and zero_point are registered buffers (see __init__), so the
        # parent call presumably already serializes them; the explicit writes below store
        # the live buffer tensors under the same keys. Kept to preserve existing output.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # scale/zero_point are created with a single element (see __init__) but grow to
        # one element per channel after per-channel observation. Resize the local buffers
        # to the saved shape first; without this, loading fails with a size mismatch
        # between the loaded tensor and the current buffer.
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key not in state_dict:
                # Both names are registered buffers, so the parent implementation already
                # reports them in missing_keys when strict; appending here too would list
                # each missing key twice.
                continue
            val = state_dict[key]
            # Custom handling to allow loading scale and zero_point of size N into
            # buffers of a different size. The buffers are resized here, and the values
            # are copied in the default state_dict loading code of the parent.
            if name == 'scale':
                self.scale.resize_(val.shape)
            else:
                assert name == 'zero_point'
                self.zero_point.resize_(val.shape)
            # For torchscript module we need to update the attributes here since we do not
            # call the `_load_from_state_dict` function defined in module.py
            if torch.jit.is_scripting():
                if name == 'scale':
                    self.scale.copy_(val)
                else:
                    assert name == 'zero_point'
                    self.zero_point.copy_(val)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
class FixedQParamsFakeQuantize(FakeQuantizeBase):
    """ Simulate quantize and dequantize with fixed quantization
    parameters in training time. Only per tensor quantization
    is supported.

    Args:
        scale (float): fixed scale for the fake quantize module
        zero_point (int): fixed zero point for the fake quantize module
        dtype: quantized dtype being emulated (default ``torch.quint8``)
        qscheme: quantization scheme; must be a per-tensor scheme
            (default ``torch.per_tensor_affine``)
        quant_min (int): minimum allowable quantized value (default 0)
        quant_max (int): maximum allowable quantized value (default 255)
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self,
                 scale,
                 zero_point,
                 dtype=torch.quint8,
                 qscheme=torch.per_tensor_affine,
                 quant_min=0,
                 quant_max=255):
        super().__init__()
        assert quant_min <= quant_max, 'quant_min should be less than or equal to quant_max'
        self.quant_min = quant_min
        self.quant_max = quant_max
        # One-element tensor buffers, passed as-is to the fake-quantize op in forward().
        self.register_buffer('scale', torch.tensor([scale], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int))
        self.dtype = dtype
        self.qscheme = qscheme
        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported in' + \
            ' FixedQParamsFakeQuantize module, got qscheme: ' + str(self.qscheme)

    def forward(self, X):
        """Fake-quantize ``X`` with the fixed qparams when fake quantization is enabled.

        The ``observer_enabled`` flag has no effect here: there is no observer and the
        qparams never change.
        """
        if self.fake_quant_enabled[0] == 1:
            X = torch.fake_quantize_per_tensor_affine(X, self.scale,
                                                      self.zero_point, self.quant_min,
                                                      self.quant_max)
        return X

    @torch.jit.export
    def calculate_qparams(self):
        """Return the fixed ``(scale, zero_point)`` buffers."""
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
               'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.scale, self.zero_point, self.dtype,
                   self.quant_min, self.quant_max, self.qscheme)
# Default fake-quantize constructors. Each one is produced by the ``with_args``
# classmethod, binding the keyword arguments listed below.

# quint8 in [0, 255], per-tensor affine, moving-average min/max observer.
default_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)

# qint8 in [-128, 127], per-tensor symmetric, moving-average min/max observer.
default_weight_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False,
)

# TODO(future PR): remove these defaults and enforce activation functions
# to explicitly specify their output range
# Fixed qparams: scale 2/256 with zero_point 128 over quint8 [0, 255].
default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
    scale=2.0 / 256.0,
    zero_point=128,
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
)
# Fixed qparams: scale 1/256 with zero_point 0 over quint8 [0, 255].
default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
    scale=1.0 / 256.0,
    zero_point=0,
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
)

# qint8 in [-128, 127], per-channel symmetric along axis 0.
default_per_channel_weight_fake_quant = FakeQuantize.with_args(
    observer=MovingAveragePerChannelMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    reduce_range=False,
    ch_axis=0,
)

# quint8 in [0, 255], per-tensor affine, histogram observer.
default_histogram_fake_quant = FakeQuantize.with_args(
    observer=HistogramObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)
def _is_fake_quant_script_module(mod):
''' Returns true if given mod is an instance of FakeQuantize script module.
'''
if isinstance(mod, torch.jit.RecursiveScriptModule):
# qualified name looks like '__torch__.torch.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
suffix = mod._c.qualified_name.split('.', 1)[1]
name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
return name == 'torch.quantization.fake_quantize.FakeQuantize'
return False
def disable_fake_quant(mod):
    """Call ``mod.disable_fake_quant()`` when ``mod`` is a FakeQuantizeBase instance
    or a scripted FakeQuantize; any other module is left untouched."""
    if not (isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)):
        return
    mod.disable_fake_quant()
def enable_fake_quant(mod):
    """Call ``mod.enable_fake_quant()`` when ``mod`` is a FakeQuantizeBase instance
    or a scripted FakeQuantize; any other module is left untouched."""
    if not (isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)):
        return
    mod.enable_fake_quant()
def disable_observer(mod):
    """Call ``mod.disable_observer()`` when ``mod`` is a FakeQuantizeBase instance
    or a scripted FakeQuantize; any other module is left untouched."""
    if not (isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)):
        return
    mod.disable_observer()
def enable_observer(mod):
    """Call ``mod.enable_observer()`` when ``mod`` is a FakeQuantizeBase instance
    or a scripted FakeQuantize; any other module is left untouched."""
    if not (isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)):
        return
    mod.enable_observer()