[quant][fx][improvement] Renamed default_affine_fixed_qparams_observer and default_symmetric_fixed_qparams_observer (#76637)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76637

The previous names `default_affine_fixed_qparams_observer` and
`default_symmetric_fixed_qparams_observer` were uninformative: users had to read
their definitions to understand what these observers are. The new names,
`default_fixed_qparams_range_0to1_observer` and
`default_fixed_qparams_range_neg1to1_observer`, encode the fixed quantization
range of each observer directly.
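As a minimal sketch (assuming a PyTorch build that includes this rename), the
range is now readable off the name and can be confirmed from the observer's
fixed qparams:

```
import torch
from torch.ao.quantization.observer import (
    default_fixed_qparams_range_0to1_observer,     # scale=1/256, zero_point=0
    default_fixed_qparams_range_neg1to1_observer,  # scale=2/256, zero_point=128
)

# Fixed-qparams observers report the same qparams regardless of the data seen.
obs = default_fixed_qparams_range_0to1_observer()
obs(torch.rand(8))  # forward pass does not update any statistics
scale, zero_point = obs.calculate_qparams()
print(scale.item(), zero_point.item())  # 0.00390625 0 -> quint8 covering [0, 1)
# The neg1to1 variant (scale=2/256, zero_point=128) covers [-1, 1) instead.
```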

Analogous changes were made for
`default_symmetric_fixed_qparams_fake_quant` and
`default_affine_fixed_qparams_fake_quant`. The old names are kept as
backwards-compatibility aliases and are slated for removal after a few releases.
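A similar sketch for the fake quants, also showing the backwards-compatibility
aliases this diff leaves in place (old and new names are the same object, since
the alias is a plain assignment):

```
import torch
from torch.ao.quantization.fake_quantize import (
    default_fixed_qparams_range_0to1_fake_quant,
    default_affine_fixed_qparams_fake_quant,  # deprecated alias kept by this PR
)

# Both names resolve to the same partially-applied constructor.
assert default_affine_fixed_qparams_fake_quant is default_fixed_qparams_range_0to1_fake_quant

# Usage mirrors the updated test: instantiate, then fake-quantize activations.
fq = default_fixed_qparams_range_0to1_fake_quant()
y = fq(torch.sigmoid(torch.randn(4)))  # simulates quint8 quantization over [0, 1)
print(fq.scale, fq.zero_point)
```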

Test Plan:
```
python test/test_quantization.py
```

Differential Revision: D36054169

Reviewed By: vkuzo

Pulled By: dzdang

fbshipit-source-id: 215f7786a4b7abda7327f17cc61735697ec5cca9
(cherry picked from commit 21a4e6eda4467c8adca7fd534a506a14e975f9cf)
dzdang authored on 2022-05-03 19:32:55 -07:00, committed by PyTorch MergeBot
parent 4cdec1a79e
commit e2aa28a2d0
12 changed files with 73 additions and 50 deletions

@@ -44,6 +44,8 @@
"MovingAveragePerChannelMinMaxObserver",
"Tuple",
"abstractmethod",
"default_fixed_qparams_range_0to1_fake_quant",
"default_fixed_qparams_range_0to1_observer",
"default_affine_fixed_qparams_fake_quant",
"default_affine_fixed_qparams_observer",
"default_dynamic_fake_quant",
@@ -55,6 +57,8 @@
"default_fused_wt_fake_quant",
"default_histogram_fake_quant",
"default_per_channel_weight_fake_quant",
"default_fixed_qparams_range_neg1to1_fake_quant",
"default_fixed_qparams_range_neg1to1_observer",
"default_symmetric_fixed_qparams_fake_quant",
"default_symmetric_fixed_qparams_observer",
"default_weight_fake_quant",
@@ -86,6 +90,8 @@
"FixedQParamsFakeQuantize",
"List",
"ObservationType",
"default_fixed_qparams_range_0to1_observer",
"default_fixed_qparams_range_neg1to1_observer",
"default_affine_fixed_qparams_observer",
"default_symmetric_fixed_qparams_observer",
"fuse_conv_bn",
@@ -445,6 +451,8 @@
"QuantStub",
"Set",
"Union",
"default_fixed_qparams_range_0to1_fake_quant",
"default_fixed_qparams_range_neg1to1_fake_quant",
"default_affine_fixed_qparams_fake_quant",
"default_symmetric_fixed_qparams_fake_quant",
"get_combined_dict",
@@ -2872,6 +2880,7 @@
"convert",
"convert_dynamic_jit",
"convert_jit",
"default_fixed_qparams_range_0to1_fake_quant",
"default_affine_fixed_qparams_fake_quant",
"default_debug_observer",
"default_dynamic_quant_observer",
@@ -2886,6 +2895,7 @@
"default_per_channel_weight_fake_quant",
"default_per_channel_weight_observer",
"default_placeholder_observer",
"default_fixed_qparams_range_neg1to1_fake_quant",
"default_symmetric_fixed_qparams_fake_quant",
"default_weight_fake_quant",
"default_weight_observer",
@@ -2938,6 +2948,7 @@
"FakeQuantizeBase",
"FixedQParamsFakeQuantize",
"FusedMovingAvgObsFakeQuantize",
"default_fixed_qparams_range_0to1_fake_quant",
"default_affine_fixed_qparams_fake_quant",
"default_fake_quant",
"default_fused_act_fake_quant",
@@ -2945,6 +2956,7 @@
"default_fused_wt_fake_quant",
"default_histogram_fake_quant",
"default_per_channel_weight_fake_quant",
"default_fixed_qparams_range_neg1to1_fake_quant",
"default_symmetric_fixed_qparams_fake_quant",
"default_weight_fake_quant",
"disable_fake_quant",

@@ -111,8 +111,8 @@ class TestAOMigration(AOMigrationTestCase):
'FusedMovingAvgObsFakeQuantize',
'default_fake_quant',
'default_weight_fake_quant',
'default_symmetric_fixed_qparams_fake_quant',
'default_affine_fixed_qparams_fake_quant',
'default_fixed_qparams_range_neg1to1_fake_quant',
'default_fixed_qparams_range_0to1_fake_quant',
'default_per_channel_weight_fake_quant',
'default_histogram_fake_quant',
'default_fused_act_fake_quant',

@@ -81,8 +81,8 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
'FusedMovingAvgObsFakeQuantize',
'default_fake_quant',
'default_weight_fake_quant',
'default_symmetric_fixed_qparams_fake_quant',
'default_affine_fixed_qparams_fake_quant',
'default_fixed_qparams_range_neg1to1_fake_quant',
'default_fixed_qparams_range_0to1_fake_quant',
'default_per_channel_weight_fake_quant',
'default_histogram_fake_quant',
'default_fused_act_fake_quant',

@@ -7,7 +7,7 @@ from torch.ao.quantization import (
FakeQuantize,
MovingAverageMinMaxObserver,
default_observer,
default_affine_fixed_qparams_fake_quant,
default_fixed_qparams_range_0to1_fake_quant,
)
from torch.ao.quantization._learnable_fake_quantize import _LearnableFakeQuantize
@@ -544,7 +544,7 @@ class TestFakeQuantizeOps(TestCase):
def test_fixed_qparams_fq_module(self, device, X):
X, (scale, zero_point, torch_type) = X
X = to_tensor(X, device)
fq_module = default_affine_fixed_qparams_fake_quant()
fq_module = default_fixed_qparams_range_0to1_fake_quant()
fq_module.to(device)
fixed_scale = fq_module.scale.clone()
fixed_zero_point = fq_module.zero_point.clone()

@@ -83,13 +83,13 @@ from torch.ao.quantization.fx.pattern_utils import (
from torch.ao.quantization.fx.utils import NodeInfo
from torch.ao.quantization.fake_quantize import (
default_affine_fixed_qparams_fake_quant,
default_symmetric_fixed_qparams_fake_quant,
default_fixed_qparams_range_0to1_fake_quant,
default_fixed_qparams_range_neg1to1_fake_quant,
)
from torch.ao.quantization.observer import (
default_affine_fixed_qparams_observer,
default_symmetric_fixed_qparams_observer,
default_fixed_qparams_range_0to1_observer,
default_fixed_qparams_range_neg1to1_observer,
)
# test utils
@@ -4085,11 +4085,11 @@ class TestQuantizeFx(QuantizationTestCase):
class DummyQuant():
pass
@register_quant_pattern("dummy_quant2", default_affine_fixed_qparams_observer)
@register_quant_pattern("dummy_quant2", default_fixed_qparams_range_0to1_observer)
class DummyQuant2():
pass
@register_quant_pattern("dummy_quant3", default_symmetric_fixed_qparams_observer)
@register_quant_pattern("dummy_quant3", default_fixed_qparams_range_neg1to1_observer)
class DummyQuant3():
pass
@@ -4097,17 +4097,17 @@ class TestQuantizeFx(QuantizationTestCase):
self.assertEqual(DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"], DummyQuant)
self.assertEqual(DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"], DummyQuant2)
self.assertEqual(DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"], DummyQuant3)
self.assertEqual(DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"], default_affine_fixed_qparams_observer)
self.assertEqual(DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"], default_symmetric_fixed_qparams_observer)
self.assertEqual(DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"], default_fixed_qparams_range_0to1_observer)
self.assertEqual(DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"], default_fixed_qparams_range_neg1to1_observer)
self._assertFixedQParamsFakeQuantizeEqual(DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"],
default_affine_fixed_qparams_fake_quant)
default_fixed_qparams_range_0to1_fake_quant)
self._assertFixedQParamsFakeQuantizeEqual(DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"],
default_symmetric_fixed_qparams_fake_quant)
default_fixed_qparams_range_neg1to1_fake_quant)
output_fake_quantize_map = get_default_output_activation_post_process_map(is_training=True)
output_observer_map = get_default_output_activation_post_process_map(is_training=False)
self.assertEqual(output_observer_map.get("dummy_quant3"), default_symmetric_fixed_qparams_observer)
self.assertEqual(output_observer_map.get("dummy_quant3"), default_fixed_qparams_range_neg1to1_observer)
self._assertFixedQParamsFakeQuantizeEqual(output_fake_quantize_map.get("dummy_quant3"),
default_symmetric_fixed_qparams_fake_quant)
default_fixed_qparams_range_neg1to1_fake_quant)

@@ -10,8 +10,8 @@ import torch.nn.intrinsic.qat as nniqat
import torch.nn.qat as nnqat
import torch.nn.quantized._reference as nnqr
from ..observer import (
default_affine_fixed_qparams_observer,
default_symmetric_fixed_qparams_observer,
default_fixed_qparams_range_0to1_observer,
default_fixed_qparams_range_neg1to1_observer,
)
from ..fake_quantize import FixedQParamsFakeQuantize
from ..fuser_method_mappings import (
@@ -484,19 +484,19 @@ def _get_binary_op_configs(dtype_configs):
def _get_fixed_qparams_op_configs():
fixed_qparams_op_configs = []
for fixed_qparam_op, output_observer in [
(torch.nn.Hardsigmoid, default_affine_fixed_qparams_observer),
(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_observer),
("hardsigmoid", default_affine_fixed_qparams_observer),
("hardsigmoid_", default_affine_fixed_qparams_observer),
(torch.nn.Sigmoid, default_affine_fixed_qparams_observer),
(torch.sigmoid, default_affine_fixed_qparams_observer),
("sigmoid", default_affine_fixed_qparams_observer),
("sigmoid_", default_affine_fixed_qparams_observer),
(torch.nn.Tanh, default_symmetric_fixed_qparams_observer),
(torch.tanh, default_symmetric_fixed_qparams_observer),
("tanh", default_symmetric_fixed_qparams_observer),
("tanh_", default_symmetric_fixed_qparams_observer),
(torch.nn.Softmax, default_affine_fixed_qparams_observer),
(torch.nn.Hardsigmoid, default_fixed_qparams_range_0to1_observer),
(torch.nn.functional.hardsigmoid, default_fixed_qparams_range_0to1_observer),
("hardsigmoid", default_fixed_qparams_range_0to1_observer),
("hardsigmoid_", default_fixed_qparams_range_0to1_observer),
(torch.nn.Sigmoid, default_fixed_qparams_range_0to1_observer),
(torch.sigmoid, default_fixed_qparams_range_0to1_observer),
("sigmoid", default_fixed_qparams_range_0to1_observer),
("sigmoid_", default_fixed_qparams_range_0to1_observer),
(torch.nn.Tanh, default_fixed_qparams_range_neg1to1_observer),
(torch.tanh, default_fixed_qparams_range_neg1to1_observer),
("tanh", default_fixed_qparams_range_neg1to1_observer),
("tanh_", default_fixed_qparams_range_neg1to1_observer),
(torch.nn.Softmax, default_fixed_qparams_range_0to1_observer),
]:
fixed_qparams_op_configs.append({
"pattern": fixed_qparam_op,

@@ -10,8 +10,8 @@ from torch.ao.quantization.observer import (
HistogramObserver,
MovingAveragePerChannelMinMaxObserver,
FixedQParamsObserver,
default_affine_fixed_qparams_observer,
default_symmetric_fixed_qparams_observer,
default_fixed_qparams_range_0to1_observer,
default_fixed_qparams_range_neg1to1_observer,
_with_args,
)
import re
@@ -352,8 +352,15 @@ default_dynamic_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMax
Default dynamic fake_quant for activations.
"""
default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(observer=default_symmetric_fixed_qparams_observer)
default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(observer=default_affine_fixed_qparams_observer)
default_fixed_qparams_range_neg1to1_fake_quant = (
FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_neg1to1_observer)
)
default_fixed_qparams_range_0to1_fake_quant = (
FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)
)
# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
default_symmetric_fixed_qparams_fake_quant = default_fixed_qparams_range_neg1to1_fake_quant
default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant
default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
quant_min=-128,

@@ -31,7 +31,7 @@ DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()
# Mapping from pattern to activation_post_process(observer/fake_quant) constructor for output activation
# e.g. pattern: torch.sigmoid,
# output_activation_post_process: default_affine_fixed_qparams_fake_quant
# output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP = dict()
DEFAULT_OUTPUT_OBSERVER_MAP = dict()

@@ -1495,10 +1495,14 @@ Default observer for a floating point zero-point and 4 bit activations.
# TODO(future PR): remove these defaults and enforce activation functions
# to explicitly specify their output range
default_symmetric_fixed_qparams_observer = FixedQParamsObserver.with_args(
default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args(
scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
default_affine_fixed_qparams_observer = FixedQParamsObserver.with_args(
default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args(
scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)
# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer
default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer
"""
Default observers for fixed qparams operations.
"""

@@ -19,8 +19,8 @@ from typing import Optional, Union, Dict, Set, Callable, Any
import torch.ao.nn as ao_nn
from torch.ao.quantization.stubs import QuantStub, DeQuantStub
from torch.ao.quantization.fake_quantize import (
default_affine_fixed_qparams_fake_quant,
default_symmetric_fixed_qparams_fake_quant,
default_fixed_qparams_range_0to1_fake_quant,
default_fixed_qparams_range_neg1to1_fake_quant,
)
from torch.ao.quantization.utils import get_combined_dict
from torch.nn.utils.parametrize import type_before_parametrizations
@@ -156,10 +156,10 @@ DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callab
# mapping from module to output activation post process class
DEFAULT_MODULE_TO_ACT_POST_PROCESS : Dict[Callable, Callable] = {
nn.Hardsigmoid: default_affine_fixed_qparams_fake_quant,
nn.Sigmoid: default_affine_fixed_qparams_fake_quant,
nn.Softmax: default_affine_fixed_qparams_fake_quant,
nn.Tanh: default_symmetric_fixed_qparams_fake_quant,
nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant,
nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant,
nn.Softmax: default_fixed_qparams_range_0to1_fake_quant,
nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant,
}
# Default map for swapping float module to static sparse quantized ones

@@ -49,8 +49,8 @@ _all__ = [
'default_per_channel_weight_observer',
# FakeQuantize (for qat)
'default_fake_quant', 'default_weight_fake_quant',
'default_symmetric_fixed_qparams_fake_quant',
'default_affine_fixed_qparams_fake_quant',
'default_fixed_qparams_range_neg1to1_fake_quant',
'default_fixed_qparams_range_0to1_fake_quant',
'default_per_channel_weight_fake_quant',
'default_histogram_fake_quant',
# QConfig

@@ -17,8 +17,8 @@ from torch.ao.quantization.fake_quantize import (
FusedMovingAvgObsFakeQuantize,
default_fake_quant,
default_weight_fake_quant,
default_symmetric_fixed_qparams_fake_quant,
default_affine_fixed_qparams_fake_quant,
default_fixed_qparams_range_neg1to1_fake_quant,
default_fixed_qparams_range_0to1_fake_quant,
default_per_channel_weight_fake_quant,
default_histogram_fake_quant,
default_fused_act_fake_quant,