mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Use noop observer to pass dtype for dynamic quantization (#26709)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26709 Polishes implementation from #25975. Primarily, we use NoopObserver to communicate that weights need to be quantized to float16. The very top-level API (quantize_dynamic) stays the same with `dtype` argument but the implementation follows the common flow. One can argue that dynamic fp16 quantization doesn't really fit into the 'observer' mechanism. It's in fact not ideal, but it's better to have the same flow than branching on both dtype and qconfig. Test Plan: Imported from OSS Differential Revision: D17544103 Pulled By: dzhulgakov fbshipit-source-id: 6af3f18c35929a1a53ea734079c005f656e4925f
This commit is contained in:
parent
ae0732cde3
commit
128a65e2e0
|
|
@ -485,20 +485,8 @@ class PostTrainingDynamicQuantTest(QuantizationTestCase):
|
|||
|
||||
ref = copy.deepcopy(cell)
|
||||
|
||||
qconfig_dynamic_dict = {
|
||||
torch.nn.LSTM: default_dynamic_qconfig,
|
||||
}
|
||||
default_dynamic_module_mapping = {
|
||||
torch.nn.LSTM: torch.nn.quantized.dynamic.LSTM,
|
||||
}
|
||||
model_int8 = quantize_dynamic(
|
||||
model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping,
|
||||
dtype=torch.qint8
|
||||
)
|
||||
model_fp16 = quantize_dynamic(
|
||||
model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping,
|
||||
dtype=torch.float16
|
||||
)
|
||||
model_int8 = quantize_dynamic(model=model, dtype=torch.qint8)
|
||||
model_fp16 = quantize_dynamic(model=model, dtype=torch.float16)
|
||||
cell_int8 = model_int8.lstm
|
||||
cell_fp16 = model_fp16.lstm
|
||||
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class Linear(nnq.Linear):
|
|||
"""
|
||||
assert type(mod) == NNLinear, 'nn.quantized.dynamic.Linear.from_float only works for nn.Linear'
|
||||
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
|
||||
if mod.qconfig is not None and mod.qconfig.weight() is not None:
|
||||
if mod.qconfig is not None and mod.qconfig.weight is not None:
|
||||
weight_observer = mod.qconfig.weight()
|
||||
else:
|
||||
# We have the circular import issues if we import the qconfig in the beginning of this file:
|
||||
|
|
|
|||
|
|
@ -201,39 +201,35 @@ class RNNBase(torch.nn.Module):
|
|||
self._all_weight_values.append(torch.ops.quantized.linear_prepack(*dynamic_vals[i]))
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, mod, dtype=torch.qint8):
|
||||
def from_float(cls, mod):
|
||||
assert type(mod) == torch.nn.LSTM, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM'
|
||||
assert hasattr(
|
||||
mod, 'qconfig'), 'Input float module must have qconfig defined'
|
||||
|
||||
if mod.qconfig is not None and mod.qconfig.weight is not None:
|
||||
weight_observer = mod.qconfig.weight()
|
||||
else:
|
||||
# We have the circular import issues if we import the qconfig in the beginning of this file:
|
||||
# https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
|
||||
# import until we need it.
|
||||
from torch.quantization.QConfig import default_dynamic_qconfig
|
||||
weight_observer = default_dynamic_qconfig.weight()
|
||||
|
||||
dtype = weight_observer.dtype
|
||||
supported_scalar_types = [torch.qint8, torch.float16]
|
||||
if dtype not in supported_scalar_types:
|
||||
raise RuntimeError('Unsupported dtype: {}'.format(dtype))
|
||||
|
||||
# When dtype = torch.float16, we don't need weight_observer
|
||||
if dtype == torch.qint8:
|
||||
if mod.qconfig is not None and mod.qconfig.weight() is not None:
|
||||
weight_observer = mod.qconfig.weight()
|
||||
else:
|
||||
# We have the circular import issues if we import the qconfig in the beginning of this file:
|
||||
# https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
|
||||
# import until we need it.
|
||||
from torch.quantization.QConfig import default_dynamic_qconfig
|
||||
weight_observer = default_dynamic_qconfig.weight()
|
||||
assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
|
||||
raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))
|
||||
|
||||
if mod.mode == 'LSTM':
|
||||
qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
|
||||
mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
|
||||
else:
|
||||
raise NotImplementedError('Only LSTM is supported for QuantizedRNN for now')
|
||||
|
||||
num_directions = 2 if mod.bidirectional else 1
|
||||
|
||||
assert mod.bias
|
||||
|
||||
# TODO: support more than just LSTM
|
||||
if qRNNBase.mode != 'LSTM':
|
||||
raise RuntimeError('Only LSTM is supported for QuantizedRNN')
|
||||
|
||||
qRNNBase._all_weight_names = []
|
||||
qRNNBase._all_weight_values = []
|
||||
for layer in range(qRNNBase.num_layers):
|
||||
|
|
@ -372,5 +368,5 @@ class LSTM(RNNBase):
|
|||
return self.forward_tensor(input, hx)
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, mod, dtype=torch.qint8):
|
||||
return super(LSTM, cls).from_float(mod, dtype)
|
||||
def from_float(cls, mod):
|
||||
return super(LSTM, cls).from_float(mod)
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class QConfigDynamic(namedtuple('QConfigDynamic', ['weight'])):
|
|||
return super(QConfigDynamic, cls).__new__(cls, weight)
|
||||
|
||||
default_dynamic_qconfig = QConfigDynamic(weight=default_weight_observer)
|
||||
float16_dynamic_qconfig = QConfigDynamic(weight=NoopObserver.with_args(dtype=torch.float16))
|
||||
|
||||
default_qat_qconfig = QConfig(activation=default_fake_quant,
|
||||
weight=default_weight_fake_quant)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ _all__ = [
|
|||
'Observer', 'WeightObserver', 'observer', 'default_observer',
|
||||
'default_weight_observer',
|
||||
# QConfig
|
||||
'QConfig', 'default_qconfig', 'default_dynamic_qconfig',
|
||||
'QConfig', 'default_qconfig', 'default_dynamic_qconfig', 'float16_dynamic_qconfig',
|
||||
# QAT utilities
|
||||
'default_qat_qconfig', 'prepare_qat', 'quantize_qat',
|
||||
# module transformations
|
||||
|
|
|
|||
|
|
@ -39,21 +39,39 @@ _PartialWrapper.with_args = _with_args
|
|||
ABC = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3:
|
||||
|
||||
|
||||
class ObserverBase(ABC, nn.Module):
|
||||
r"""Observer base Module
|
||||
Any concrete observer implementation should derive from this class.
|
||||
class Observer(ABC, nn.Module):
|
||||
r"""
|
||||
Observer base Module. Any observer implementation should derive from this class.
|
||||
|
||||
Concrete observers should follow the same API. In forward, they will update
|
||||
the statistics of the observed Tensor. And they should provide a
|
||||
`calculate_qparams` function that computes the quantization parameters given
|
||||
the collected statistics.
|
||||
"""
|
||||
def __init__(self, dtype):
|
||||
super(Observer, self).__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def calculate_qparams(self, **kwargs):
|
||||
pass
|
||||
|
||||
with_args = classmethod(_with_args)
|
||||
|
||||
|
||||
class _ObserverBase(Observer):
|
||||
r"""
|
||||
Common base for all qint/quint8 observers
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False
|
||||
):
|
||||
super(ObserverBase, self).__init__()
|
||||
self.dtype = dtype
|
||||
super(_ObserverBase, self).__init__(dtype=dtype)
|
||||
self.qscheme = qscheme
|
||||
self.reduce_range = reduce_range
|
||||
|
||||
|
|
@ -71,14 +89,6 @@ class ObserverBase(ABC, nn.Module):
|
|||
torch.quint8,
|
||||
), "Default Observer only works for qint8 and quint8 data type"
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def calculate_qparams(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _calculate_per_channel_qparams(self, min_vals, max_vals):
|
||||
# type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
|
||||
"""
|
||||
|
|
@ -158,10 +168,8 @@ class ObserverBase(ABC, nn.Module):
|
|||
|
||||
return torch.tensor([scale]), torch.tensor([zero_point])
|
||||
|
||||
with_args = classmethod(_with_args)
|
||||
|
||||
|
||||
class MinMaxObserver(ObserverBase):
|
||||
class MinMaxObserver(_ObserverBase):
|
||||
r"""Default Observer Module
|
||||
A default implementation of the observer module, only works for
|
||||
`per_tensor_affine` quantization scheme. The module will record the
|
||||
|
|
@ -216,7 +224,7 @@ class MinMaxObserver(ObserverBase):
|
|||
return "min_val={}, max_val={}".format(self.min_val, self.max_val)
|
||||
|
||||
|
||||
class PerChannelMinMaxObserver(ObserverBase):
|
||||
class PerChannelMinMaxObserver(_ObserverBase):
|
||||
r"""Per Channel Observer Module
|
||||
The module will record the running average of max and min value for each
|
||||
channel of the observed Tensor and calculate_qparams will calculate
|
||||
|
|
@ -266,7 +274,7 @@ class PerChannelMinMaxObserver(ObserverBase):
|
|||
|
||||
|
||||
|
||||
class HistogramObserver(ObserverBase):
|
||||
class HistogramObserver(_ObserverBase):
|
||||
r"""
|
||||
The module records the running histogram of tensor values along with
|
||||
min/max values. calculate_qparams will calculate scale and zero_point
|
||||
|
|
@ -521,7 +529,7 @@ class HistogramObserver(ObserverBase):
|
|||
return self._calculate_qparams(new_min.item(), new_max.item())
|
||||
|
||||
|
||||
class RecordingObserver(ObserverBase):
|
||||
class RecordingObserver(_ObserverBase):
|
||||
r"""
|
||||
The module is mainly for debug and records the tensor values during runtime
|
||||
"""
|
||||
|
|
@ -544,6 +552,26 @@ class RecordingObserver(ObserverBase):
|
|||
return self.tensor_val
|
||||
|
||||
|
||||
class NoopObserver(Observer):
|
||||
r"""
|
||||
Observer that doesn't do anything and just passes its configuration to the
|
||||
quantized module's ``.from_float()`.
|
||||
|
||||
Primarily used for quantization to float16 which doesn't require determining
|
||||
ranges.
|
||||
"""
|
||||
def __init__(self, dtype=torch.float16):
|
||||
if dtype != torch.float16:
|
||||
raise ValueError("Only float16 quantization can be used without calibration process")
|
||||
super(NoopObserver, self).__init__(dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
def calculate_qparams(self):
|
||||
raise Exception("calculate_qparams should not be called for NoopObserver")
|
||||
|
||||
|
||||
# Restrict activations to be in the range (0,127)
|
||||
default_observer = MinMaxObserver.with_args(reduce_range=True)
|
||||
default_debug_observer = RecordingObserver
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch.nn._intrinsic.quantized as nniq
|
|||
import torch.nn._intrinsic.qat as nniqat
|
||||
import torch.nn.quantized as nnq
|
||||
import torch.nn.quantized.dynamic as nnqd
|
||||
from .QConfig import default_dynamic_qconfig
|
||||
from .QConfig import default_dynamic_qconfig, float16_dynamic_qconfig
|
||||
import torch.nn.qat as nnqat
|
||||
|
||||
|
||||
|
|
@ -256,20 +256,47 @@ def quantize(model, run_fn, run_args, mapping=DEFAULT_MODULE_MAPPING):
|
|||
convert(model, mapping)
|
||||
return model
|
||||
|
||||
DEFAULT_QCONFIG_DICT = {
|
||||
nn.Linear : default_dynamic_qconfig,
|
||||
nn.LSTM : default_dynamic_qconfig,
|
||||
}
|
||||
def quantize_dynamic(model, qconfig_dict=None, dtype=torch.qint8, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING):
|
||||
r"""Converts a float model to dynamic (i.e. weights-only) quantized model.
|
||||
|
||||
def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING, dtype=torch.qint8):
|
||||
r"""Converts a float model to dynamic quantized model.
|
||||
Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.
|
||||
|
||||
Perform dynamic training and output a quantized model.
|
||||
For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
|
||||
by default is performed for layers with large weights size - i.e. Linear and RNN variants.
|
||||
|
||||
Fine grained control is possible with `qconfig_dict` and `mapping` that act similarly to `quantize()`.
|
||||
If `qconfig_dict` is provided, the `dtype` argument is ignored.
|
||||
|
||||
Args:
|
||||
module: input model
|
||||
qconfig_dict: dictionary that maps from name or type of submodule to quantization
|
||||
configuration, qconfig applies to all submodules of a given
|
||||
module unless qconfig for the submodules are specified (when the
|
||||
submodule already has qconfig attribute). Entries in the dictionary
|
||||
need to be QConfigDynamic instances.
|
||||
mapping: maps type of a submodule to a type of corresponding dynamically quantized version
|
||||
with which the submodule needs to be replaced
|
||||
"""
|
||||
if qconfig_dict is None:
|
||||
if dtype == torch.qint8:
|
||||
qconfig_dict = {
|
||||
nn.Linear : default_dynamic_qconfig,
|
||||
nn.LSTM : default_dynamic_qconfig,
|
||||
}
|
||||
elif dtype == torch.float16:
|
||||
qconfig_dict = {
|
||||
# TODO: uncomment when float16 Linear support is added
|
||||
# nn.Linear : default_dynamic_qconfig,
|
||||
nn.LSTM : float16_dynamic_qconfig,
|
||||
}
|
||||
else:
|
||||
raise ValueError(
|
||||
"Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
|
||||
|
||||
model = copy.deepcopy(model)
|
||||
model.eval()
|
||||
propagate_qconfig(model, qconfig_dict)
|
||||
convert(model, mapping, dtype)
|
||||
convert(model, mapping)
|
||||
return model
|
||||
|
||||
def prepare_qat(model):
|
||||
|
|
@ -295,7 +322,7 @@ def quantize_qat(model, run_fn, run_args):
|
|||
convert(model)
|
||||
return model
|
||||
|
||||
def convert(module, mapping=DEFAULT_MODULE_MAPPING, dtype=torch.qint8):
|
||||
def convert(module, mapping=DEFAULT_MODULE_MAPPING):
|
||||
r"""Converts the float module with observers(where we can get quantization
|
||||
parameters) to a quantized module.
|
||||
Args:
|
||||
|
|
@ -312,13 +339,13 @@ def convert(module, mapping=DEFAULT_MODULE_MAPPING, dtype=torch.qint8):
|
|||
|
||||
for name, mod in module.named_children():
|
||||
if type(mod) not in SWAPPABLE_MODULES:
|
||||
convert(mod, mapping, dtype)
|
||||
reassign[name] = swap_module(mod, mapping, dtype)
|
||||
convert(mod, mapping)
|
||||
reassign[name] = swap_module(mod, mapping)
|
||||
|
||||
for key, value in reassign.items():
|
||||
module._modules[key] = value
|
||||
|
||||
def swap_module(mod, mapping, dtype=torch.qint8):
|
||||
def swap_module(mod, mapping):
|
||||
r"""Swaps the module if it has a quantized counterpart and it has an
|
||||
`observer` attached.
|
||||
|
||||
|
|
@ -332,14 +359,7 @@ def swap_module(mod, mapping, dtype=torch.qint8):
|
|||
new_mod = mod
|
||||
if hasattr(mod, 'qconfig') and mod.qconfig is not None:
|
||||
if type(mod) in mapping:
|
||||
supported_scalar_types = [torch.qint8, torch.float16]
|
||||
if dtype not in supported_scalar_types:
|
||||
raise RuntimeError('Unsupported dtype: {}'.format(dtype))
|
||||
if dtype == torch.qint8:
|
||||
new_mod = mapping[type(mod)].from_float(mod)
|
||||
elif dtype == torch.float16:
|
||||
# We want to support float16 dynamic quantization
|
||||
new_mod = mapping[type(mod)].from_float(mod, dtype)
|
||||
new_mod = mapping[type(mod)].from_float(mod)
|
||||
return new_mod
|
||||
|
||||
def get_observer_dict(mod, target_dict, prefix=""):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user