mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73863 This PR fully aligns the convert function with the design in https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md and simplifies the implementation of the convert function by always producing a reference quantized model (with reference patterns) first, and then lowering that model to a quantized model that is runnable with the PyTorch native backends (fbgemm/qnnpack). This PR makes convert.py much easier to understand than the previous implementation, and lets us remove the majority of the code in quantization_patterns.py as well (in follow-up PRs). Test Plan: ``` python test/test_quantization.py TestQuantizeFx python test/test_quantization.py TestQuantizeFxOps python test/test_quantization.py TestFXNumericSuiteCoreAPIs python test/test_quantization.py TestFXNumericSuiteCoreAPIsModels ``` and other internal/OSS regression tests. Imported from OSS. Reviewed By: andrewor14 Differential Revision: D34778506 fbshipit-source-id: 0678b66addf736039a8749b352f6f569caca962b (cherry picked from commit 33ec9caf23f3ab373d827117efbd9db0668b2437)
127 lines
5.8 KiB
Python
127 lines
5.8 KiB
Python
import torch
|
|
import torch.nn.quantized as nnq
|
|
import torch.nn.intrinsic as nni
|
|
from torch.nn.quantized.modules.utils import _quantize_weight
|
|
|
|
class Linear(nnq.Linear):
    r"""
    A dynamic quantized linear module with floating point tensor as inputs and outputs.
    We adopt the same interface as `torch.nn.Linear`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module which are of
                         shape :math:`(\text{out\_features}, \text{in\_features})`.
        bias (Tensor): the non-learnable floating point bias of the module of shape
                       :math:`(\text{out\_features})`. If the constructor argument
                       ``bias_`` is ``True``, the values are initialized to zero.
        version (int or None): serialization version of this instance. It is 4 for a
                       freshly constructed module and is replaced by the version found
                       in the state dict metadata on load; ``forward`` uses it to pick
                       the qint8 kernel call.

    Examples::

        >>> m = nn.quantized.dynamic.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    # version used in this class is different from the parent class nnq.Linear.
    # Instances at version 4 or newer call the qint8 kernel with
    # ``reduce_range=True`` (see ``forward``).
    _version = 4

    def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
        """Create a dynamic quantized linear module.

        Args:
            in_features (int): size of each input sample.
            out_features (int): size of each output sample.
            bias_ (bool): forwarded to ``nnq.Linear`` (presumably controls whether
                a bias is allocated).
            dtype (torch.dtype): weight dtype forwarded to ``nnq.Linear``;
                ``forward`` supports ``torch.qint8`` and ``torch.float16``.
        """
        super(Linear, self).__init__(in_features, out_features, bias_, dtype=dtype)
        # We don't muck around with buffers or attributes or anything here
        # to keep the module simple. *everything* is simply a Python attribute.
        # Serialization logic is explicitly handled in the below serialization and
        # deserialization modules.
        # A new module is at the current version; loading a state dict replaces
        # this with the saved version (see ``_load_from_state_dict``).
        self.version = 4

    def forward(self, x):
        """Apply the dynamic quantized linear op to the floating point input ``x``.

        The kernel is selected by the dtype of the packed weight:

        * ``torch.qint8``: ``torch.ops.quantized.linear_dynamic``. Modules whose
          ``version`` is 4 or newer pass ``reduce_range=True``; unversioned or
          older modules call it without that argument.
        * ``torch.float16``: ``torch.ops.quantized.linear_dynamic_fp16``.

        Returns:
            Tensor: the kernel output cast to ``x.dtype``.

        Raises:
            RuntimeError: if the packed weight has any other dtype.
        """
        # Note that we can handle self.bias == None case.
        if self._packed_params.dtype == torch.qint8:
            if self.version is None or self.version < 4:
                # Older checkpoints keep the kernel call they were created with.
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params)
            else:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params, reduce_range=True)
        elif self._packed_params.dtype == torch.float16:
            Y = torch.ops.quantized.linear_dynamic_fp16(
                x, self._packed_params._packed_params)
        else:
            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
        return Y.to(x.dtype)

    def _get_name(self):
        """Return the module name shown in this module's repr."""
        return 'DynamicQuantizedLinear'

    def extra_repr(self):
        """Describe the feature sizes and weight dtype (plus qscheme for qint8)."""
        extra_repr_str = 'in_features={}, out_features={}, dtype={}'.format(
            self.in_features, self.out_features, self._packed_params.dtype
        )
        if self._packed_params.dtype == torch.qint8:
            # Only a quantized weight tensor has a qscheme.
            extra_repr_str += ', qscheme={}'.format(self.weight().qscheme())
        return extra_repr_str

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Record the saved serialization version, then defer to ``nnq.Linear``.

        The version read from ``local_metadata`` (``None`` when absent) is stored
        on ``self.version`` so that ``forward`` reproduces the kernel call of the
        version the checkpoint was saved with.
        """
        version = local_metadata.get('version', None)
        self.version = version
        # NOTE(review): ``strict`` is not forwarded; ``False`` is passed to the
        # parent instead. Presumably deliberate — confirm before changing.
        super(Linear, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                  missing_keys, unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, mod):
        r"""Create a dynamic quantized module from a float module.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user. Its exact type must be
                          one of ``torch.nn.Linear``,
                          ``torch.nn.modules.linear.NonDynamicallyQuantizableLinear``,
                          ``torch.nn.intrinsic.modules.fused.LinearReLU`` or
                          ``torch.nn.qat.dynamic.Linear``, and it must have a
                          ``qconfig`` attribute (which may be ``None``).

        Returns:
            Linear: a new module built from ``mod``'s weight (quantized for qint8,
            ``mod.weight.float()`` for float16) and ``mod.bias``.

        Raises:
            AssertionError: if ``mod`` has an unsupported type, has no ``qconfig``
                attribute, or its weight observer has a dtype other than qint8 or
                float16.

        Note:
            When the qconfig is ``None`` or has no weight observer factory,
            ``default_dynamic_qconfig`` is used. For a fused ``LinearReLU`` only the
            inner ``Linear`` is converted; the ReLU is not part of the result.
        """
        float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
                         torch.nn.intrinsic.modules.fused.LinearReLU, torch.nn.qat.dynamic.Linear]

        # Exact type match: subclasses of the accepted types are rejected.
        assert type(mod) in float_modules, (
            'nn.quantized.dynamic.Linear.from_float only works for one of '
            + str([float_mod.__name__ for float_mod in float_modules]))
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        # The check above is made on the module passed in, so remember its
        # qconfig before unwrapping a fused module.
        qconfig = mod.qconfig
        if type(mod) == nni.LinearReLU:
            # Only the Linear inside the fused module is converted.
            mod = mod[0]
            # Prefer the child's own qconfig when it has one; otherwise fall back
            # to the fused module's instead of failing with AttributeError.
            qconfig = getattr(mod, 'qconfig', qconfig)
        if qconfig is not None and qconfig.weight is not None:
            weight_observer = qconfig.weight()
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()
        dtype = weight_observer.dtype
        assert dtype in [torch.qint8, torch.float16], "The only supported dtypes for " \
            "dynamic quantized linear are qint8 and float16 got: {}".format(dtype)
        # Feed the float weight through the observer before it is used below
        # by _quantize_weight (qint8 path).
        weight_observer(mod.weight)
        if dtype == torch.qint8:
            qweight = _quantize_weight(mod.weight.float(), weight_observer)
        elif dtype == torch.float16:
            qweight = mod.weight.float()
        else:
            # Only reachable when assertions are disabled (python -O).
            raise RuntimeError('Unsupported dtype specified for dynamic quantized Linear!')
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear

    @classmethod
    def from_reference(cls, ref_qlinear):
        """ Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized
        module

        Args:
            ref_qlinear (Module): a reference quantized module, either produced by
            torch.ao.quantization functions or provided by the user. Its
            ``in_features``, ``out_features``, ``weight_dtype``, ``bias`` and
            ``get_quantized_weight()`` are used.

        Returns:
            Linear: a new module holding ``ref_qlinear``'s quantized weight and bias.
        """
        qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features, dtype=ref_qlinear.weight_dtype)
        qweight = ref_qlinear.get_quantized_weight()
        bias = ref_qlinear.bias
        qlinear.set_weight_bias(qweight, bias)
        return qlinear