Use noop observer to pass dtype for dynamic quantization (#26709)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26709

Polishes the implementation from #25975. Primarily, we use NoopObserver to communicate that weights need to be quantized to float16. The top-level API (`quantize_dynamic`) keeps its `dtype` argument, but the implementation now follows the common observer-based flow.

One can argue that dynamic fp16 quantization doesn't really fit the 'observer' mechanism. It is indeed not a perfect fit, but a single flow is better than branching on both dtype and qconfig.
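For illustration (not part of this commit), a minimal sketch of the resulting user-facing flow; the toy model and its sizes are made up, and only `nn.LSTM` gets a float16 qconfig by default:

```python
import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic
from torch.quantization.QConfig import float16_dynamic_qconfig

# Stand-in for a real float model; LSTM is the layer type covered by
# fp16 dynamic quantization in this change.
float_model = nn.Sequential(nn.LSTM(8, 16))

# Top-level API is unchanged: `dtype` selects the default qconfig dict.
model_fp16 = quantize_dynamic(model=float_model, dtype=torch.float16)

# Equivalent explicit form: the dtype travels inside the qconfig as a
# NoopObserver(dtype=torch.float16), which from_float() reads later.
# `dtype` is ignored once qconfig_dict is given.
model_fp16_explicit = quantize_dynamic(
    model=float_model,
    qconfig_dict={nn.LSTM: float16_dynamic_qconfig},
)
```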

Test Plan: Imported from OSS

Differential Revision: D17544103

Pulled By: dzhulgakov

fbshipit-source-id: 6af3f18c35929a1a53ea734079c005f656e4925f
Dmytro Dzhulgakov 2019-09-24 09:19:15 -07:00 committed by Facebook Github Bot
parent ae0732cde3
commit 128a65e2e0
7 changed files with 109 additions and 76 deletions

View File

@@ -485,20 +485,8 @@ class PostTrainingDynamicQuantTest(QuantizationTestCase):
ref = copy.deepcopy(cell)
qconfig_dynamic_dict = {
torch.nn.LSTM: default_dynamic_qconfig,
}
default_dynamic_module_mapping = {
torch.nn.LSTM: torch.nn.quantized.dynamic.LSTM,
}
model_int8 = quantize_dynamic(
model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping,
dtype=torch.qint8
)
model_fp16 = quantize_dynamic(
model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping,
dtype=torch.float16
)
model_int8 = quantize_dynamic(model=model, dtype=torch.qint8)
model_fp16 = quantize_dynamic(model=model, dtype=torch.float16)
cell_int8 = model_int8.lstm
cell_fp16 = model_fp16.lstm

View File

@@ -52,7 +52,7 @@ class Linear(nnq.Linear):
"""
assert type(mod) == NNLinear, 'nn.quantized.dynamic.Linear.from_float only works for nn.Linear'
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
if mod.qconfig is not None and mod.qconfig.weight() is not None:
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
else:
# We have the circular import issues if we import the qconfig in the beginning of this file:

View File

@@ -201,39 +201,35 @@ class RNNBase(torch.nn.Module):
self._all_weight_values.append(torch.ops.quantized.linear_prepack(*dynamic_vals[i]))
@classmethod
def from_float(cls, mod, dtype=torch.qint8):
def from_float(cls, mod):
assert type(mod) == torch.nn.LSTM, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM'
assert hasattr(
mod, 'qconfig'), 'Input float module must have qconfig defined'
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
else:
# We have the circular import issues if we import the qconfig in the beginning of this file:
# https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
# import until we need it.
from torch.quantization.QConfig import default_dynamic_qconfig
weight_observer = default_dynamic_qconfig.weight()
dtype = weight_observer.dtype
supported_scalar_types = [torch.qint8, torch.float16]
if dtype not in supported_scalar_types:
raise RuntimeError('Unsupported dtype: {}'.format(dtype))
# When dtype = torch.float16, we don't need weight_observer
if dtype == torch.qint8:
if mod.qconfig is not None and mod.qconfig.weight() is not None:
weight_observer = mod.qconfig.weight()
else:
# We have the circular import issues if we import the qconfig in the beginning of this file:
# https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
# import until we need it.
from torch.quantization.QConfig import default_dynamic_qconfig
weight_observer = default_dynamic_qconfig.weight()
assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))
if mod.mode == 'LSTM':
qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
else:
raise NotImplementedError('Only LSTM is supported for QuantizedRNN for now')
num_directions = 2 if mod.bidirectional else 1
assert mod.bias
# TODO: support more than just LSTM
if qRNNBase.mode != 'LSTM':
raise RuntimeError('Only LSTM is supported for QuantizedRNN')
qRNNBase._all_weight_names = []
qRNNBase._all_weight_values = []
for layer in range(qRNNBase.num_layers):
@@ -372,5 +368,5 @@ class LSTM(RNNBase):
return self.forward_tensor(input, hx)
@classmethod
def from_float(cls, mod, dtype=torch.qint8):
return super(LSTM, cls).from_float(mod, dtype)
def from_float(cls, mod):
return super(LSTM, cls).from_float(mod)
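Not in the patch itself: a simplified sketch of the dispatch pattern that `from_float()` now follows, i.e. the target dtype is derived from the qconfig's weight observer instead of being passed in explicitly. The helper name is hypothetical:

```python
import torch
from torch.quantization.QConfig import default_dynamic_qconfig, float16_dynamic_qconfig

def _weight_dtype_from_qconfig(qconfig):
    # Hypothetical helper mirroring the from_float() logic above: instantiate
    # the weight observer from the qconfig and read the dtype it carries.
    weight_observer = qconfig.weight()
    dtype = weight_observer.dtype
    if dtype not in (torch.qint8, torch.float16):
        raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))
    return weight_observer, dtype

# qint8 path: a real MinMaxObserver that later produces scale/zero_point.
_, dtype = _weight_dtype_from_qconfig(default_dynamic_qconfig)
assert dtype == torch.qint8

# float16 path: a NoopObserver whose only job is to carry dtype=torch.float16.
_, dtype = _weight_dtype_from_qconfig(float16_dynamic_qconfig)
assert dtype == torch.float16
```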

View File

@@ -58,6 +58,7 @@ class QConfigDynamic(namedtuple('QConfigDynamic', ['weight'])):
return super(QConfigDynamic, cls).__new__(cls, weight)
default_dynamic_qconfig = QConfigDynamic(weight=default_weight_observer)
float16_dynamic_qconfig = QConfigDynamic(weight=NoopObserver.with_args(dtype=torch.float16))
default_qat_qconfig = QConfig(activation=default_fake_quant,
weight=default_weight_fake_quant)
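A short illustrative sketch (not in the patch) of assembling a custom `QConfigDynamic` from observer factories, mirroring the two defaults defined above:

```python
import torch
from torch.quantization.QConfig import QConfigDynamic
from torch.quantization.observer import MinMaxObserver, NoopObserver

# QConfigDynamic only carries a weight-observer factory (no activation
# observer); `with_args` binds constructor arguments without instantiating.
my_int8_qconfig = QConfigDynamic(
    weight=MinMaxObserver.with_args(dtype=torch.qint8,
                                    qscheme=torch.per_tensor_symmetric))
my_fp16_qconfig = QConfigDynamic(
    weight=NoopObserver.with_args(dtype=torch.float16))

assert my_int8_qconfig.weight().dtype == torch.qint8
assert my_fp16_qconfig.weight().dtype == torch.float16
```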

View File

@@ -26,7 +26,7 @@ _all__ = [
'Observer', 'WeightObserver', 'observer', 'default_observer',
'default_weight_observer',
# QConfig
'QConfig', 'default_qconfig', 'default_dynamic_qconfig',
'QConfig', 'default_qconfig', 'default_dynamic_qconfig', 'float16_dynamic_qconfig',
# QAT utilities
'default_qat_qconfig', 'prepare_qat', 'quantize_qat',
# module transformations

View File

@@ -39,21 +39,39 @@ _PartialWrapper.with_args = _with_args
ABC = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3:
class ObserverBase(ABC, nn.Module):
r"""Observer base Module
Any concrete observer implementation should derive from this class.
class Observer(ABC, nn.Module):
r"""
Observer base Module. Any observer implementation should derive from this class.
Concrete observers should follow the same API. In forward, they will update
the statistics of the observed Tensor. And they should provide a
`calculate_qparams` function that computes the quantization parameters given
the collected statistics.
"""
def __init__(self, dtype):
super(Observer, self).__init__()
self.dtype = dtype
@abstractmethod
def forward(self, x):
pass
@abstractmethod
def calculate_qparams(self, **kwargs):
pass
with_args = classmethod(_with_args)
class _ObserverBase(Observer):
r"""
Common base for all qint/quint8 observers
"""
def __init__(
self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False
):
super(ObserverBase, self).__init__()
self.dtype = dtype
super(_ObserverBase, self).__init__(dtype=dtype)
self.qscheme = qscheme
self.reduce_range = reduce_range
@@ -71,14 +89,6 @@ class ObserverBase(ABC, nn.Module):
torch.quint8,
), "Default Observer only works for qint8 and quint8 data type"
@abstractmethod
def forward(self, x):
pass
@abstractmethod
def calculate_qparams(self, **kwargs):
pass
def _calculate_per_channel_qparams(self, min_vals, max_vals):
# type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
"""
@@ -158,10 +168,8 @@ class ObserverBase(ABC, nn.Module):
return torch.tensor([scale]), torch.tensor([zero_point])
with_args = classmethod(_with_args)
class MinMaxObserver(ObserverBase):
class MinMaxObserver(_ObserverBase):
r"""Default Observer Module
A default implementation of the observer module, only works for
`per_tensor_affine` quantization scheme. The module will record the
@@ -216,7 +224,7 @@ class MinMaxObserver(ObserverBase):
return "min_val={}, max_val={}".format(self.min_val, self.max_val)
class PerChannelMinMaxObserver(ObserverBase):
class PerChannelMinMaxObserver(_ObserverBase):
r"""Per Channel Observer Module
The module will record the running average of max and min value for each
channel of the observed Tensor and calculate_qparams will calculate
@@ -266,7 +274,7 @@ class PerChannelMinMaxObserver(ObserverBase):
class HistogramObserver(ObserverBase):
class HistogramObserver(_ObserverBase):
r"""
The module records the running histogram of tensor values along with
min/max values. calculate_qparams will calculate scale and zero_point
@@ -521,7 +529,7 @@ class HistogramObserver(ObserverBase):
return self._calculate_qparams(new_min.item(), new_max.item())
class RecordingObserver(ObserverBase):
class RecordingObserver(_ObserverBase):
r"""
The module is mainly for debug and records the tensor values during runtime
"""
@@ -544,6 +552,26 @@ class RecordingObserver(ObserverBase):
return self.tensor_val
class NoopObserver(Observer):
r"""
Observer that doesn't do anything and just passes its configuration to the
quantized module's ``.from_float()``.
Primarily used for quantization to float16 which doesn't require determining
ranges.
"""
def __init__(self, dtype=torch.float16):
if dtype != torch.float16:
raise ValueError("Only float16 quantization can be used without calibration process")
super(NoopObserver, self).__init__(dtype=dtype)
def forward(self, x):
return x
def calculate_qparams(self):
raise Exception("calculate_qparams should not be called for NoopObserver")
# Restrict activations to be in the range (0,127)
default_observer = MinMaxObserver.with_args(reduce_range=True)
default_debug_observer = RecordingObserver
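A small usage sketch (not part of the change) of `NoopObserver` on its own; it only carries the target dtype and never computes statistics:

```python
import torch
from torch.quantization.observer import NoopObserver

obs = NoopObserver()            # defaults to (and only accepts) torch.float16
x = torch.randn(4, 4)
assert obs(x) is x              # forward is a pure pass-through
assert obs.dtype == torch.float16

# `with_args` (inherited from Observer) returns a factory with bound
# arguments, which is what QConfigDynamic stores in its `weight` slot.
make_obs = NoopObserver.with_args(dtype=torch.float16)
assert make_obs().dtype == torch.float16

# Asking for qparams is an error by design -- fp16 dynamic quantization
# needs no calibration:
#   obs.calculate_qparams()  # raises
```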

View File

@@ -8,7 +8,7 @@ import torch.nn._intrinsic.quantized as nniq
import torch.nn._intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from .QConfig import default_dynamic_qconfig
from .QConfig import default_dynamic_qconfig, float16_dynamic_qconfig
import torch.nn.qat as nnqat
@@ -256,20 +256,47 @@ def quantize(model, run_fn, run_args, mapping=DEFAULT_MODULE_MAPPING):
convert(model, mapping)
return model
DEFAULT_QCONFIG_DICT = {
nn.Linear : default_dynamic_qconfig,
nn.LSTM : default_dynamic_qconfig,
}
def quantize_dynamic(model, qconfig_dict=None, dtype=torch.qint8, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING):
r"""Converts a float model to dynamic (i.e. weights-only) quantized model.
def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING, dtype=torch.qint8):
r"""Converts a float model to dynamic quantized model.
Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.
Perform dynamic training and output a quantized model.
For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
by default is performed for layers with large weights size - i.e. Linear and RNN variants.
Fine grained control is possible with `qconfig_dict` and `mapping` that act similarly to `quantize()`.
If `qconfig_dict` is provided, the `dtype` argument is ignored.
Args:
module: input model
qconfig_dict: dictionary that maps from name or type of submodule to quantization
configuration, qconfig applies to all submodules of a given
module unless qconfig for the submodules are specified (when the
submodule already has qconfig attribute). Entries in the dictionary
need to be QConfigDynamic instances.
mapping: maps type of a submodule to a type of corresponding dynamically quantized version
with which the submodule needs to be replaced
"""
if qconfig_dict is None:
if dtype == torch.qint8:
qconfig_dict = {
nn.Linear : default_dynamic_qconfig,
nn.LSTM : default_dynamic_qconfig,
}
elif dtype == torch.float16:
qconfig_dict = {
# TODO: uncomment when float16 Linear support is added
# nn.Linear : default_dynamic_qconfig,
nn.LSTM : float16_dynamic_qconfig,
}
else:
raise ValueError(
"Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
model = copy.deepcopy(model)
model.eval()
propagate_qconfig(model, qconfig_dict)
convert(model, mapping, dtype)
convert(model, mapping)
return model
def prepare_qat(model):
@@ -295,7 +322,7 @@ def quantize_qat(model, run_fn, run_args):
convert(model)
return model
def convert(module, mapping=DEFAULT_MODULE_MAPPING, dtype=torch.qint8):
def convert(module, mapping=DEFAULT_MODULE_MAPPING):
r"""Converts the float module with observers(where we can get quantization
parameters) to a quantized module.
Args:
@@ -312,13 +339,13 @@ def convert(module, mapping=DEFAULT_MODULE_MAPPING, dtype=torch.qint8):
for name, mod in module.named_children():
if type(mod) not in SWAPPABLE_MODULES:
convert(mod, mapping, dtype)
reassign[name] = swap_module(mod, mapping, dtype)
convert(mod, mapping)
reassign[name] = swap_module(mod, mapping)
for key, value in reassign.items():
module._modules[key] = value
def swap_module(mod, mapping, dtype=torch.qint8):
def swap_module(mod, mapping):
r"""Swaps the module if it has a quantized counterpart and it has an
`observer` attached.
@@ -332,14 +359,7 @@ def swap_module(mod, mapping, dtype=torch.qint8):
new_mod = mod
if hasattr(mod, 'qconfig') and mod.qconfig is not None:
if type(mod) in mapping:
supported_scalar_types = [torch.qint8, torch.float16]
if dtype not in supported_scalar_types:
raise RuntimeError('Unsupported dtype: {}'.format(dtype))
if dtype == torch.qint8:
new_mod = mapping[type(mod)].from_float(mod)
elif dtype == torch.float16:
# We want to support float16 dynamic quantization
new_mod = mapping[type(mod)].from_float(mod, dtype)
new_mod = mapping[type(mod)].from_float(mod)
return new_mod
def get_observer_dict(mod, target_dict, prefix=""):
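Finally, an end-to-end usage sketch (illustrative only, not part of the diff) of the resulting `quantize_dynamic` API; `TinyModel` and its layer sizes are hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import quantize_dynamic

class TinyModel(nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.fc = nn.Linear(16, 16)
        self.lstm = nn.LSTM(16, 16)

    def forward(self, x, hx=None):
        return self.lstm(self.fc(x), hx)

float_model = TinyModel()

# qint8: Linear and LSTM are both swapped for dynamically quantized versions.
model_int8 = quantize_dynamic(model=float_model, dtype=torch.qint8)
assert isinstance(model_int8.lstm, nnqd.LSTM)
assert isinstance(model_int8.fc, nnqd.Linear)

# float16: only the LSTM is swapped; Linear fp16 support is still a TODO
# in the default qconfig dict above, so `fc` stays a float nn.Linear.
model_fp16 = quantize_dynamic(model=float_model, dtype=torch.float16)
assert isinstance(model_fp16.lstm, nnqd.LSTM)
assert isinstance(model_fp16.fc, nn.Linear)
```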