pytorch/torch/nn/quantized/dynamic/modules/linear.py
Jianyu Huang 8224398c14 [pytorch] Fix the extra_repr print message for float16 dynamic quantization (#36044)
Summary:
When applying the float16 dynamic quantization with
```
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.float16
        )
        print(model)
```
there is an issue when we try to print the model: the `qscheme` information cannot be printed for a float16 weight, because such a weight is not quantized with the per-tensor or per-channel schemes that are defined for int8 dynamic quantization.

Before this PR:
```
Traceback (most recent call last):
  File "dlrm_s_pytorch.py", line 860, in <module>
    print(dlrm)
  File "/home/jianyuhuang/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1142, in __repr__
    mod_str = repr(module)
  File "/home/jianyuhuang/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1142, in __repr__
    mod_str = repr(module)
  File "/home/jianyuhuang/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1136, in __repr__
    extra_repr = self.extra_repr()
  File "/home/jianyuhuang/miniconda3/lib/python3.7/site-packages/torch/nn/quantized/dynamic/modules/linear.py", line 55, in extra_repr
    self.in_features, self.out_features, self.weight().qscheme()
RuntimeError: Could not run 'aten::qscheme' with arguments from the 'CPUTensorId' backend. 'aten::qscheme' is only available for these backends: [QuantizedCPUTensorId, VariableTensorId].
```

After this PR:
```
    (4): DynamicQuantizedLinear(
      in_features=2, out_features=1, dtype=torch.float16
      (_packed_params): LinearPackedParams()
    )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36044

Differential Revision: D20860811

Pulled By: jianyuh

fbshipit-source-id: d1405a185f46a8110e6d27982b40534c854f4d1c
2020-04-05 14:27:42 -07:00

91 lines
4.2 KiB
Python

from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from ....modules.linear import Linear as NNLinear
import torch.nn.quantized as nnq
from torch.nn.quantized.modules.utils import _quantize_weight
class Linear(nnq.Linear):
    r"""
    A dynamic quantized linear module with floating point tensors as inputs and outputs.

    We adopt the same interface as `torch.nn.Linear`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias_ (bool): forwarded to the parent constructor. Default: ``True``
        dtype (torch.dtype): weight dtype, either ``torch.qint8`` or
            ``torch.float16``. Default: ``torch.qint8``

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module which are of
                         shape :math:`(\text{out\_features}, \text{in\_features})`.
        bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`.
                       If :attr:`bias` is ``True``, the values are initialized to zero.

    Examples::

        >>> m = nn.quantized.dynamic.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
        super(Linear, self).__init__(in_features, out_features, bias_, dtype=dtype)
        # We don't muck around with buffers or attributes or anything here
        # to keep the module simple. *everything* is simply a Python attribute.
        # Serialization logic is explicitly handled in the below serialization and
        # deserialization modules

    def forward(self, x):
        r"""Apply the linear transform to ``x`` using the packed parameters.

        The kernel is selected from the dtype recorded on ``_packed_params``
        and the result is cast back to the dtype of ``x``.

        Args:
            x (Tensor): the input tensor.

        Returns:
            Tensor: the output, converted to ``x.dtype``.

        Raises:
            RuntimeError: if the packed parameter dtype is neither
                ``torch.qint8`` nor ``torch.float16``.
        """
        # Note that we can handle self.bias == None case.
        if self._packed_params.dtype == torch.qint8:
            Y = torch.ops.quantized.linear_dynamic(
                x, self._packed_params._packed_params)
        elif self._packed_params.dtype == torch.float16:
            Y = torch.ops.quantized.linear_dynamic_fp16(
                x, self._packed_params._packed_params)
        else:
            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
        return Y.to(x.dtype)

    def _get_name(self):
        r"""Return the name used for this module when it is printed."""
        return 'DynamicQuantizedLinear'

    def extra_repr(self):
        r"""Return the feature sizes and dtype, plus the qscheme for qint8 weights."""
        extra_repr_str = 'in_features={}, out_features={}, dtype={}'.format(
            self.in_features, self.out_features, self._packed_params.dtype
        )
        # Only qint8 weights are quantized tensors; asking a float16 weight for
        # its qscheme raises a RuntimeError, so the field is omitted in that case.
        if self._packed_params.dtype == torch.qint8:
            extra_repr_str += ', qscheme={}'.format(self.weight().qscheme())
        return extra_repr_str

    @classmethod
    def from_float(cls, mod):
        r"""Create a dynamic quantized module from a float module or qparams_dict

        Args:
            mod (Module): a float module, either produced by torch.quantization
                          utilities or provided by the user

        Returns:
            An instance of ``cls`` whose weight and bias were set from ``mod``.

        Raises:
            AssertionError: if ``mod`` is not exactly an ``nn.Linear``, has no
                ``qconfig`` attribute, or its weight observer reports a dtype
                other than ``torch.qint8`` / ``torch.float16``.
        """
        # Exact type comparison (not isinstance): subclasses of nn.Linear are rejected.
        assert type(mod) == NNLinear, 'nn.quantized.dynamic.Linear.from_float only works for nn.Linear'
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.quantization.qconfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()
        dtype = weight_observer.dtype
        assert dtype in [torch.qint8, torch.float16], 'The only supported dtypes for dynamic quantized linear are qint8 and float16'
        # Feed the weight through the observer before it is used for quantization.
        weight_observer(mod.weight)
        if dtype == torch.qint8:
            qweight = _quantize_weight(mod.weight.float(), weight_observer)
        elif dtype == torch.float16:
            # The float16 path hands the float weight over as-is.
            qweight = mod.weight.float()
        else:
            # Only reachable when the assert above is stripped (python -O).
            raise RuntimeError('Unsupported dtype specified for dynamic quantized Linear!')
        # Build through ``cls`` rather than the hard-coded ``Linear`` so that a
        # subclass inheriting this classmethod gets an instance of itself back.
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear