pytorch/torch/quantization
Supriya Rao 864d129bae [quant][fx] Remove extra q-dq for weight bias in normalization ops (#59882)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59882

Currently, for normalization ops the weight and bias arguments are treated as activation inputs that require observers.
This results in extra quant-dequant ops being inserted for the weight and bias inputs.

This PR adds support for skipping observation of the weight/bias inputs of normalization operators, removing the redundant q-dq ops.
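For context, a minimal sketch of the kind of module and FX graph mode quantization flow that produces graphs like the ones shown below. The module definition, input shape, and qconfig are illustrative assumptions; only the F.layer_norm call with weight/bias module attributes and the [2, 5, 5] normalized shape come from the graphs in this summary.
```
import torch
import torch.nn.functional as F
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # weight/bias passed to F.layer_norm; shape matches normalized_shape [2, 5, 5]
        self.scale = torch.nn.Parameter(torch.ones(2, 5, 5))
        self.bias = torch.nn.Parameter(torch.zeros(2, 5, 5))

    def forward(self, x):
        return F.layer_norm(x, [2, 5, 5], weight=self.scale, bias=self.bias)

m = M().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(m, qconfig_dict)   # insert observers
prepared(torch.randn(1, 2, 5, 5))        # calibration pass
quantized = convert_fx(prepared)         # lower to quantized ops
print(quantized.code)                    # prints a graph like the ones below
```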

Quantized graph with F.layer_norm
Before this PR:
```
def forward(self, x):
    _input_scale_0 = self._input_scale_0
    _input_zero_point_0 = self._input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, _input_scale_0, _input_zero_point_0, torch.quint8);  x = _input_scale_0 = _input_zero_point_0 = None
    scale = self.scale
    _input_scale_1 = self._input_scale_1
    _input_zero_point_1 = self._input_zero_point_1
    quantize_per_tensor_1 = torch.quantize_per_tensor(scale, _input_scale_1, _input_zero_point_1, torch.quint8);  scale = _input_scale_1 = _input_zero_point_1 = None
    bias = self.bias
    _input_scale_2 = self._input_scale_2
    _input_zero_point_2 = self._input_zero_point_2
    quantize_per_tensor_2 = torch.quantize_per_tensor(bias, _input_scale_2, _input_zero_point_2, torch.quint8);  bias = _input_scale_2 = _input_zero_point_2 = None
    _scale_0 = self._scale_0
    _zero_point_0 = self._zero_point_0
    dequantize = quantize_per_tensor_1.dequantize();  quantize_per_tensor_1 = None
    dequantize_1 = quantize_per_tensor_2.dequantize();  quantize_per_tensor_2 = None
    layer_norm = torch.ops.quantized.layer_norm(quantize_per_tensor, [2, 5, 5], weight = dequantize, bias = dequantize_1, eps = 1e-05, output_scale = _scale_0, output_zero_point = _zero_point_0);  quantize_per_tensor = dequantize = dequantize_1 = _scale_0 = _zero_point_0 = None
    dequantize_2 = layer_norm.dequantize();  layer_norm = None
    return dequantize_2
```
After this PR:
```
def forward(self, x):
    _input_scale_0 = self._input_scale_0
    _input_zero_point_0 = self._input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, _input_scale_0, _input_zero_point_0, torch.quint8);  x = _input_scale_0 = _input_zero_point_0 = None
    scale = self.scale
    bias = self.bias
    _scale_0 = self._scale_0
    _zero_point_0 = self._zero_point_0
    layer_norm = torch.ops.quantized.layer_norm(quantize_per_tensor, [2, 5, 5], weight = scale, bias = bias, eps = 1e-05, output_scale = _scale_0, output_zero_point = _zero_point_0);  quantize_per_tensor = scale = bias = _scale_0 = _zero_point_0 = None
    dequantize = layer_norm.dequantize();  layer_norm = None
    return dequantize
```

Test Plan:
python test/test_quantization.py TestQuantizeFxOps.test_norm_weight_bias

Imported from OSS

Reviewed By: HDCharles, ailzhang

Differential Revision: D29068203

fbshipit-source-id: 24b5c38bbea5fd355d34522bfa654c9db18607da
2021-06-11 16:22:36 -07:00
fx [quant][fx] Remove extra q-dq for weight bias in normalization ops (#59882) 2021-06-11 16:22:36 -07:00
ns [quant][graphmode][fx][refactor] Split quantize.py to prepare.py and convert.py (#59353) 2021-06-02 23:52:39 -07:00
__init__.py Un-ignore F403 in .flake8 (#55838) 2021-04-13 09:24:07 -07:00
_correct_bias.py
_equalize.py [quant] Eager mode equalization support for ConvReLU and LinearReLU (#58792) 2021-05-24 17:25:13 -07:00
_learnable_fake_quantize.py [quant][graphmode][fx][refactor] Split quantize.py to prepare.py and convert.py (#59353) 2021-06-02 23:52:39 -07:00
_numeric_suite_fx.py ns for fx: move relatedness mapping to mappings file (#57171) 2021-05-05 06:29:11 -07:00
_numeric_suite.py Add lint for unqualified type: ignore (#56290) 2021-04-21 08:07:23 -07:00
fake_quantize.py memory efficient per-channel fq: use it everywhere, delete old version (#51265) 2021-01-28 19:42:25 -08:00
fuse_modules.py Add lint for unqualified noqa (#56272) 2021-04-19 13:16:18 -07:00
fuser_method_mappings.py fix docstring for fusing functions (#58638) 2021-05-24 18:27:22 -07:00
observer.py Move _PartialWrapper to module scope (#59660) 2021-06-09 11:55:04 -07:00
qconfig.py Un-ignore F403 in .flake8 (#55838) 2021-04-13 09:24:07 -07:00
quant_type.py [quant][graphmode][fx] custom_module support static/dynamic/weight_only quant (#46786) 2020-10-27 21:41:33 -07:00
quantization_mappings.py fix nn.MHA scriptability (#58727) 2021-05-26 15:29:49 -07:00
quantize_fx.py fx quant: enable qconfig_dict to target function invocations by order (#59605) 2021-06-11 08:53:40 -07:00
quantize_jit.py Enable the quantization on XPU devices (#54857) 2021-05-20 17:02:13 -07:00
quantize.py [quant][eager][fix] Fix a typo in convert function in eager mode quantization (#59571) 2021-06-08 10:24:22 -07:00
stubs.py type check for torch.quantization.stubs (#46475) 2020-10-16 15:34:23 -07:00
utils.py [quant][graphmode][fx] Fix a condition check for CopyNode (#53585) 2021-03-11 09:32:20 -08:00