[quant][fx][improvement] Renamed default_affine_fixed_qparams_observer and default_symmetric_fixed_qparams_observer (#76637)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76637

The previous names, `default_affine_fixed_qparams_observer` and `default_symmetric_fixed_qparams_observer`, were uninformative: users had to read each observer's definition to understand what it does. The new names reveal the quantization range each observer targets. Analogous changes were made for `default_symmetric_fixed_qparams_fake_quant` and `default_affine_fixed_qparams_fake_quant`.

Test Plan:
```
python test/test_quantization.py
```

Differential Revision: D36054169

Reviewed By: vkuzo

Pulled By: dzdang

fbshipit-source-id: 215f7786a4b7abda7327f17cc61735697ec5cca9
(cherry picked from commit 21a4e6eda4467c8adca7fd534a506a14e975f9cf)
parent 4cdec1a79e
commit e2aa28a2d0
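The rename is backward compatible: the old names are kept as aliases of the new ones (see the observer and fake quantize definition hunks below). A minimal sketch of what this means for user code, assuming a PyTorch build that includes this commit:

```python
# Minimal sketch, assuming a PyTorch build that includes this commit.
# The old name is kept as a plain alias of the new one, so both
# resolve to the same FixedQParamsObserver partial.
from torch.ao.quantization.observer import (
    default_fixed_qparams_range_0to1_observer,  # new name: output range ~[0, 1]
    default_affine_fixed_qparams_observer,      # old name, kept for backward compatibility
)

assert default_affine_fixed_qparams_observer is default_fixed_qparams_range_0to1_observer

obs = default_fixed_qparams_range_0to1_observer()
scale, zero_point = obs.calculate_qparams()  # fixed: scale=1/256, zero_point=0
```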
```diff
@@ -44,6 +44,8 @@
     "MovingAveragePerChannelMinMaxObserver",
     "Tuple",
     "abstractmethod",
+    "default_fixed_qparams_range_0to1_fake_quant",
+    "default_fixed_qparams_range_0to1_observer",
     "default_affine_fixed_qparams_fake_quant",
     "default_affine_fixed_qparams_observer",
     "default_dynamic_fake_quant",
@@ -55,6 +57,8 @@
     "default_fused_wt_fake_quant",
     "default_histogram_fake_quant",
     "default_per_channel_weight_fake_quant",
+    "default_fixed_qparams_range_neg1to1_fake_quant",
+    "default_fixed_qparams_range_neg1to1_observer",
     "default_symmetric_fixed_qparams_fake_quant",
     "default_symmetric_fixed_qparams_observer",
     "default_weight_fake_quant",
@@ -86,6 +90,8 @@
     "FixedQParamsFakeQuantize",
     "List",
     "ObservationType",
+    "default_fixed_qparams_range_0to1_observer",
+    "default_fixed_qparams_range_neg1to1_observer",
     "default_affine_fixed_qparams_observer",
     "default_symmetric_fixed_qparams_observer",
     "fuse_conv_bn",
@@ -445,6 +451,8 @@
     "QuantStub",
     "Set",
     "Union",
+    "default_fixed_qparams_range_0to1_fake_quant",
+    "default_fixed_qparams_range_neg1to1_fake_quant",
     "default_affine_fixed_qparams_fake_quant",
     "default_symmetric_fixed_qparams_fake_quant",
     "get_combined_dict",
@@ -2872,6 +2880,7 @@
     "convert",
     "convert_dynamic_jit",
     "convert_jit",
+    "default_fixed_qparams_range_0to1_fake_quant",
     "default_affine_fixed_qparams_fake_quant",
     "default_debug_observer",
     "default_dynamic_quant_observer",
@@ -2886,6 +2895,7 @@
     "default_per_channel_weight_fake_quant",
     "default_per_channel_weight_observer",
     "default_placeholder_observer",
+    "default_fixed_qparams_range_neg1to1_fake_quant",
     "default_symmetric_fixed_qparams_fake_quant",
     "default_weight_fake_quant",
     "default_weight_observer",
@@ -2938,6 +2948,7 @@
     "FakeQuantizeBase",
     "FixedQParamsFakeQuantize",
     "FusedMovingAvgObsFakeQuantize",
+    "default_fixed_qparams_range_0to1_fake_quant",
     "default_affine_fixed_qparams_fake_quant",
     "default_fake_quant",
     "default_fused_act_fake_quant",
@@ -2945,6 +2956,7 @@
     "default_fused_wt_fake_quant",
     "default_histogram_fake_quant",
     "default_per_channel_weight_fake_quant",
+    "default_fixed_qparams_range_neg1to1_fake_quant",
     "default_symmetric_fixed_qparams_fake_quant",
     "default_weight_fake_quant",
     "disable_fake_quant",
```
```diff
@@ -111,8 +111,8 @@ class TestAOMigration(AOMigrationTestCase):
             'FusedMovingAvgObsFakeQuantize',
             'default_fake_quant',
             'default_weight_fake_quant',
-            'default_symmetric_fixed_qparams_fake_quant',
-            'default_affine_fixed_qparams_fake_quant',
+            'default_fixed_qparams_range_neg1to1_fake_quant',
+            'default_fixed_qparams_range_0to1_fake_quant',
             'default_per_channel_weight_fake_quant',
             'default_histogram_fake_quant',
             'default_fused_act_fake_quant',
```
```diff
@@ -81,8 +81,8 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
             'FusedMovingAvgObsFakeQuantize',
             'default_fake_quant',
             'default_weight_fake_quant',
-            'default_symmetric_fixed_qparams_fake_quant',
-            'default_affine_fixed_qparams_fake_quant',
+            'default_fixed_qparams_range_neg1to1_fake_quant',
+            'default_fixed_qparams_range_0to1_fake_quant',
             'default_per_channel_weight_fake_quant',
             'default_histogram_fake_quant',
             'default_fused_act_fake_quant',
```
```diff
@@ -7,7 +7,7 @@ from torch.ao.quantization import (
     FakeQuantize,
     MovingAverageMinMaxObserver,
     default_observer,
-    default_affine_fixed_qparams_fake_quant,
+    default_fixed_qparams_range_0to1_fake_quant,
 )

 from torch.ao.quantization._learnable_fake_quantize import _LearnableFakeQuantize
@@ -544,7 +544,7 @@ class TestFakeQuantizeOps(TestCase):
     def test_fixed_qparams_fq_module(self, device, X):
         X, (scale, zero_point, torch_type) = X
         X = to_tensor(X, device)
-        fq_module = default_affine_fixed_qparams_fake_quant()
+        fq_module = default_fixed_qparams_range_0to1_fake_quant()
         fq_module.to(device)
         fixed_scale = fq_module.scale.clone()
         fixed_zero_point = fq_module.zero_point.clone()
```
```diff
@@ -83,13 +83,13 @@ from torch.ao.quantization.fx.pattern_utils import (
 from torch.ao.quantization.fx.utils import NodeInfo

 from torch.ao.quantization.fake_quantize import (
-    default_affine_fixed_qparams_fake_quant,
-    default_symmetric_fixed_qparams_fake_quant,
+    default_fixed_qparams_range_0to1_fake_quant,
+    default_fixed_qparams_range_neg1to1_fake_quant,
 )

 from torch.ao.quantization.observer import (
-    default_affine_fixed_qparams_observer,
-    default_symmetric_fixed_qparams_observer,
+    default_fixed_qparams_range_0to1_observer,
+    default_fixed_qparams_range_neg1to1_observer,
 )

 # test utils
@@ -4085,11 +4085,11 @@ class TestQuantizeFx(QuantizationTestCase):
         class DummyQuant():
             pass

-        @register_quant_pattern("dummy_quant2", default_affine_fixed_qparams_observer)
+        @register_quant_pattern("dummy_quant2", default_fixed_qparams_range_0to1_observer)
         class DummyQuant2():
             pass

-        @register_quant_pattern("dummy_quant3", default_symmetric_fixed_qparams_observer)
+        @register_quant_pattern("dummy_quant3", default_fixed_qparams_range_neg1to1_observer)
         class DummyQuant3():
             pass

@@ -4097,17 +4097,17 @@ class TestQuantizeFx(QuantizationTestCase):
         self.assertEqual(DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"], DummyQuant)
         self.assertEqual(DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"], DummyQuant2)
         self.assertEqual(DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"], DummyQuant3)
-        self.assertEqual(DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"], default_affine_fixed_qparams_observer)
-        self.assertEqual(DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"], default_symmetric_fixed_qparams_observer)
+        self.assertEqual(DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"], default_fixed_qparams_range_0to1_observer)
+        self.assertEqual(DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"], default_fixed_qparams_range_neg1to1_observer)
         self._assertFixedQParamsFakeQuantizeEqual(DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"],
-                                                  default_affine_fixed_qparams_fake_quant)
+                                                  default_fixed_qparams_range_0to1_fake_quant)
         self._assertFixedQParamsFakeQuantizeEqual(DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"],
-                                                  default_symmetric_fixed_qparams_fake_quant)
+                                                  default_fixed_qparams_range_neg1to1_fake_quant)
         output_fake_quantize_map = get_default_output_activation_post_process_map(is_training=True)
         output_observer_map = get_default_output_activation_post_process_map(is_training=False)
-        self.assertEqual(output_observer_map.get("dummy_quant3"), default_symmetric_fixed_qparams_observer)
+        self.assertEqual(output_observer_map.get("dummy_quant3"), default_fixed_qparams_range_neg1to1_observer)
         self._assertFixedQParamsFakeQuantizeEqual(output_fake_quantize_map.get("dummy_quant3"),
-                                                  default_symmetric_fixed_qparams_fake_quant)
+                                                  default_fixed_qparams_range_neg1to1_fake_quant)
```
```diff
@@ -10,8 +10,8 @@ import torch.nn.intrinsic.qat as nniqat
 import torch.nn.qat as nnqat
 import torch.nn.quantized._reference as nnqr
 from ..observer import (
-    default_affine_fixed_qparams_observer,
-    default_symmetric_fixed_qparams_observer,
+    default_fixed_qparams_range_0to1_observer,
+    default_fixed_qparams_range_neg1to1_observer,
 )
 from ..fake_quantize import FixedQParamsFakeQuantize
 from ..fuser_method_mappings import (
@@ -484,19 +484,19 @@ def _get_binary_op_configs(dtype_configs):
 def _get_fixed_qparams_op_configs():
     fixed_qparams_op_configs = []
     for fixed_qparam_op, output_observer in [
-        (torch.nn.Hardsigmoid, default_affine_fixed_qparams_observer),
-        (torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_observer),
-        ("hardsigmoid", default_affine_fixed_qparams_observer),
-        ("hardsigmoid_", default_affine_fixed_qparams_observer),
-        (torch.nn.Sigmoid, default_affine_fixed_qparams_observer),
-        (torch.sigmoid, default_affine_fixed_qparams_observer),
-        ("sigmoid", default_affine_fixed_qparams_observer),
-        ("sigmoid_", default_affine_fixed_qparams_observer),
-        (torch.nn.Tanh, default_symmetric_fixed_qparams_observer),
-        (torch.tanh, default_symmetric_fixed_qparams_observer),
-        ("tanh", default_symmetric_fixed_qparams_observer),
-        ("tanh_", default_symmetric_fixed_qparams_observer),
-        (torch.nn.Softmax, default_affine_fixed_qparams_observer),
+        (torch.nn.Hardsigmoid, default_fixed_qparams_range_0to1_observer),
+        (torch.nn.functional.hardsigmoid, default_fixed_qparams_range_0to1_observer),
+        ("hardsigmoid", default_fixed_qparams_range_0to1_observer),
+        ("hardsigmoid_", default_fixed_qparams_range_0to1_observer),
+        (torch.nn.Sigmoid, default_fixed_qparams_range_0to1_observer),
+        (torch.sigmoid, default_fixed_qparams_range_0to1_observer),
+        ("sigmoid", default_fixed_qparams_range_0to1_observer),
+        ("sigmoid_", default_fixed_qparams_range_0to1_observer),
+        (torch.nn.Tanh, default_fixed_qparams_range_neg1to1_observer),
+        (torch.tanh, default_fixed_qparams_range_neg1to1_observer),
+        ("tanh", default_fixed_qparams_range_neg1to1_observer),
+        ("tanh_", default_fixed_qparams_range_neg1to1_observer),
+        (torch.nn.Softmax, default_fixed_qparams_range_0to1_observer),
     ]:
         fixed_qparams_op_configs.append({
             "pattern": fixed_qparam_op,
```
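These configs pin each fixed-range op's output observer to the range the op actually produces, rather than observing statistics at runtime. A condensed, hypothetical shorthand of the pairing (the dict below is illustrative only; the real configs built above carry a fuller per-pattern schema):

```python
import torch
from torch.ao.quantization.observer import (
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
)

# Hypothetical shorthand for the pairing built by _get_fixed_qparams_op_configs():
# ops whose outputs live in [0, 1] get the 0to1 observer; tanh-like ops get neg1to1.
FIXED_QPARAMS_OUTPUT_OBSERVER = {
    torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer,      # sigmoid(x) in (0, 1)
    torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer,  # hardsigmoid(x) in [0, 1]
    torch.nn.Softmax: default_fixed_qparams_range_0to1_observer,      # softmax values in [0, 1]
    torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer,      # tanh(x) in (-1, 1)
}
```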
```diff
@@ -10,8 +10,8 @@ from torch.ao.quantization.observer import (
     HistogramObserver,
     MovingAveragePerChannelMinMaxObserver,
     FixedQParamsObserver,
-    default_affine_fixed_qparams_observer,
-    default_symmetric_fixed_qparams_observer,
+    default_fixed_qparams_range_0to1_observer,
+    default_fixed_qparams_range_neg1to1_observer,
     _with_args,
 )
 import re
@@ -352,8 +352,15 @@ default_dynamic_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMax
 Default dynamic fake_quant for activations.
 """

-default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(observer=default_symmetric_fixed_qparams_observer)
-default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(observer=default_affine_fixed_qparams_observer)
+default_fixed_qparams_range_neg1to1_fake_quant = (
+    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_neg1to1_observer)
+)
+default_fixed_qparams_range_0to1_fake_quant = (
+    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)
+)
+# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
+default_symmetric_fixed_qparams_fake_quant = default_fixed_qparams_range_neg1to1_fake_quant
+default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant

 default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                                quant_min=-128,
```
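As the new definitions make explicit, each of these fake quants is `FixedQParamsFakeQuantize` bound to the matching fixed observer. A minimal usage sketch (illustrative, assuming a build with this commit; the key property is that scale and zero point stay fixed during QAT instead of being learned from data):

```python
import torch
from torch.ao.quantization.fake_quantize import default_fixed_qparams_range_0to1_fake_quant

fq = default_fixed_qparams_range_0to1_fake_quant()  # fixed scale=1/256, zero_point=0

x = torch.rand(4)   # values in [0, 1), e.g. a sigmoid output
y = fq(x)           # fake-quantize: snap to the quint8 grid, then dequantize

print(fq.scale, fq.zero_point)     # unchanged no matter what data is seen
print((y - x).abs().max().item())  # at most ~1/256, one quantization step
```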
```diff
@@ -31,7 +31,7 @@ DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()

 # Mapping from pattern to activation_post_process(observer/fake_quant) constructor for output activation
 # e.g. pattern: torch.sigmoid,
-#      output_activation_post_process: default_affine_fixed_qparams_fake_quant
+#      output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
 DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP = dict()
 DEFAULT_OUTPUT_OBSERVER_MAP = dict()
```
```diff
@@ -1495,10 +1495,14 @@ Default observer for a floating point zero-point and 4 bit activations.

 # TODO(future PR): remove these defaults and enforce activation functions
 # to explicitly specify their output range
-default_symmetric_fixed_qparams_observer = FixedQParamsObserver.with_args(
+default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args(
     scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
-default_affine_fixed_qparams_observer = FixedQParamsObserver.with_args(
+default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args(
     scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)
+# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
+default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer
+default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer

 """
 Default observers for fixed qparams operations.
 """
```
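The new names follow directly from these qparams: for an affine scheme, the representable real range is `[(quant_min - zero_point) * scale, (quant_max - zero_point) * scale]`. A quick illustrative check (plain arithmetic, no PyTorch needed):

```python
def representable_range(scale, zero_point, quant_min=0, quant_max=255):
    """Real interval covered by an affine quint8 quantization scheme."""
    return ((quant_min - zero_point) * scale, (quant_max - zero_point) * scale)

# neg1to1 observer: scale=2/256, zero_point=128 -> (-1.0, 0.9921875), i.e. ~[-1, 1]
print(representable_range(2.0 / 256.0, 128))

# 0to1 observer: scale=1/256, zero_point=0 -> (0.0, 0.99609375), i.e. ~[0, 1]
print(representable_range(1.0 / 256.0, 0))
```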
```diff
@@ -19,8 +19,8 @@ from typing import Optional, Union, Dict, Set, Callable, Any
 import torch.ao.nn as ao_nn
 from torch.ao.quantization.stubs import QuantStub, DeQuantStub
 from torch.ao.quantization.fake_quantize import (
-    default_affine_fixed_qparams_fake_quant,
-    default_symmetric_fixed_qparams_fake_quant,
+    default_fixed_qparams_range_0to1_fake_quant,
+    default_fixed_qparams_range_neg1to1_fake_quant,
 )
 from torch.ao.quantization.utils import get_combined_dict
 from torch.nn.utils.parametrize import type_before_parametrizations
@@ -156,10 +156,10 @@ DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callab

 # mapping from module to output activation post process class
 DEFAULT_MODULE_TO_ACT_POST_PROCESS : Dict[Callable, Callable] = {
-    nn.Hardsigmoid: default_affine_fixed_qparams_fake_quant,
-    nn.Sigmoid: default_affine_fixed_qparams_fake_quant,
-    nn.Softmax: default_affine_fixed_qparams_fake_quant,
-    nn.Tanh: default_symmetric_fixed_qparams_fake_quant,
+    nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant,
+    nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant,
+    nn.Softmax: default_fixed_qparams_range_0to1_fake_quant,
+    nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant,
 }

 # Default map for swapping float module to static sparse quantized ones
```
```diff
@@ -49,8 +49,8 @@ _all__ = [
     'default_per_channel_weight_observer',
     # FakeQuantize (for qat)
     'default_fake_quant', 'default_weight_fake_quant',
-    'default_symmetric_fixed_qparams_fake_quant',
-    'default_affine_fixed_qparams_fake_quant',
+    'default_fixed_qparams_range_neg1to1_fake_quant',
+    'default_fixed_qparams_range_0to1_fake_quant',
     'default_per_channel_weight_fake_quant',
     'default_histogram_fake_quant',
     # QConfig
```
```diff
@@ -17,8 +17,8 @@ from torch.ao.quantization.fake_quantize import (
     FusedMovingAvgObsFakeQuantize,
     default_fake_quant,
     default_weight_fake_quant,
-    default_symmetric_fixed_qparams_fake_quant,
-    default_affine_fixed_qparams_fake_quant,
+    default_fixed_qparams_range_neg1to1_fake_quant,
+    default_fixed_qparams_range_0to1_fake_quant,
     default_per_channel_weight_fake_quant,
     default_histogram_fake_quant,
     default_fused_act_fake_quant,
```