pytorch/torch/ao/quantization/fx/pattern_utils.py
dzdang e2aa28a2d0 [quant][fx][improvement] Renamed default_affine_fixed_qparams_observer and default_symmetric_fixed_qparams_observer (#76637)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76637

The previous names `default_affine_fixed_qparams_observer`
and `default_symmetric_fixed_qparams_observer` were uninformative, and users had to read
their definitions in order to understand what these observers are. The new
naming convention reveals information about the range of the observers.

The analogous changes were also made for
`default_symmetric_fixed_qparams_fake_quant` and
`default_affine_fixed_qparams_fake_quant`

Test Plan:
```
python test/test_quantization.py
```

```
python test/test_quantization.py
```

Differential Revision:
D36054169
D36054169

Reviewed By: vkuzo

Pulled By: dzdang

fbshipit-source-id: 215f7786a4b7abda7327f17cc61735697ec5cca9
(cherry picked from commit 21a4e6eda4467c8adca7fd534a506a14e975f9cf)
2022-05-04 02:39:20 +00:00

90 lines
3.4 KiB
Python

from collections import OrderedDict
from typing import Dict, Any, Tuple, List, Optional
from torch.fx.graph import (
Node,
)
from torch.ao.quantization.quantization_types import Pattern
from ..qconfig import QConfigAny
from ..fake_quantize import FixedQParamsFakeQuantize
# from .quantization_patterns import BinaryOpQuantizeHandler
from ..observer import ObserverBase
import copy
# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
# Until that is resolved, QuantizeHandler is just an alias for Any, so every
# annotation in this module that mentions it is effectively unchecked.
QuantizeHandler = Any
# MatchResult is a 5-tuple type alias:
#   (Node, List[Node], Optional[Pattern], QuantizeHandler, QConfigAny)
# NOTE(review): the meaning of each slot is not visible in this file; confirm
# against the code that produces/consumes MatchResult before relying on it.
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                    QConfigAny]
# Module-level registry of fusion patterns (e.g. conv + bn fusion).
# Populated by the decorator returned from register_fusion_pattern, which
# stores the decorated object under its pattern; insertion order is preserved.
DEFAULT_FUSION_PATTERNS = OrderedDict()
def register_fusion_pattern(pattern):
    """Build a decorator that registers its target under ``pattern``.

    Args:
        pattern: key under which the decorated object is stored in
            ``DEFAULT_FUSION_PATTERNS``.

    Returns:
        A one-argument decorator. Applying it records the decorated object in
        ``DEFAULT_FUSION_PATTERNS[pattern]`` (overwriting any previous entry
        for the same pattern) and returns that object unchanged.
    """
    def _register(handler):
        # Registration is a side effect only; the decorated object is handed
        # back untouched so it can still be used under its own name.
        DEFAULT_FUSION_PATTERNS[pattern] = handler
        return handler

    return _register
def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
    """Return a shallow copy of the registered fusion patterns.

    A copy is returned so callers can add or remove entries without mutating
    the module-level ``DEFAULT_FUSION_PATTERNS`` registry. The copy is shallow:
    keys and values are the same objects as in the registry.
    """
    registered = DEFAULT_FUSION_PATTERNS
    return copy.copy(registered)
# Registry of quantization patterns, filled in by register_quant_pattern.
DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()

# Per-pattern activation_post_process (observer / fake_quant) constructor used
# for the output activation. One map holds the fake-quantize variant and the
# other the plain observer variant, e.g.
#   pattern: torch.sigmoid
#   output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP = {}
DEFAULT_OUTPUT_OBSERVER_MAP = {}
# Register a pattern for both static quantization and QAT
def register_quant_pattern(pattern, fixed_qparams_observer=None):
    """Build a decorator that registers its target as the handler for ``pattern``.

    Args:
        pattern: key under which the decorated object is stored in
            ``DEFAULT_QUANTIZATION_PATTERNS``.
        fixed_qparams_observer: optional observer for the pattern's output.
            When it is not ``None``, the pattern is additionally registered in
            ``DEFAULT_OUTPUT_OBSERVER_MAP`` (with this observer) and in
            ``DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP`` (with
            ``FixedQParamsFakeQuantize.with_args(observer=...)``).

    Returns:
        A one-argument decorator that performs the registration and returns
        the decorated object unchanged.
    """
    def _register(handler):
        DEFAULT_QUANTIZATION_PATTERNS[pattern] = handler
        if fixed_qparams_observer is None:
            # No fixed-qparams observer: the output maps are left untouched.
            return handler

        fake_quant_ctr = FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer)
        DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = fake_quant_ctr
        DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer
        return handler

    return _register
# Patterns shared by static quantization and QAT
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
    """Return a shallow copy of the registered quantization patterns.

    The copy lets callers modify the result without touching the module-level
    ``DEFAULT_QUANTIZATION_PATTERNS`` registry; keys and values themselves are
    not copied.
    """
    registered = DEFAULT_QUANTIZATION_PATTERNS
    return copy.copy(registered)
# Map from pattern to the constructor of its output activation post process,
# e.g. torch.sigmoid -> default_fixed_qparams_range_0to1_fake_quant
def get_default_output_activation_post_process_map(is_training) -> Dict[Pattern, ObserverBase]:
    """Return a shallow copy of the output activation post process map.

    Args:
        is_training: if truthy, the fake-quantize map
            (``DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP``) is used; otherwise the
            observer map (``DEFAULT_OUTPUT_OBSERVER_MAP``) is used.

    Returns:
        A shallow copy of the selected map, so that callers cannot mutate the
        module-level registry through the returned dictionary.
    """
    selected = DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP if is_training else DEFAULT_OUTPUT_OBSERVER_MAP
    return copy.copy(selected)
# Example use of the register pattern function (the whole pattern is passed
# as a single tuple argument):
# @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvOrLinearBNReLUFusion():
#     def __init__(...):
#         ...
#
def sorted_patterns_dict(patterns_dict: Dict[Pattern, QuantizeHandler]) -> Dict[Pattern, QuantizeHandler]:
    """
    Return an ordered copy of ``patterns_dict`` in which longer patterns come
    first, so that e.g. (F.relu, F.linear) is matched before F.relu.

    Tuple patterns are ordered by decreasing number of leaf entries. All
    non-tuple patterns share one sort key and are placed after every tuple
    pattern. Because the sort is stable, patterns with equal keys keep the
    relative order they had in ``patterns_dict``.

    This works for current use cases, but more complex patterns may need a
    more clever ordering.
    """

    def get_len(pattern):
        """ Count the leaf entries of a (possibly nested) pattern.

        A non-tuple counts as one entry; a tuple counts as the sum of its
        items' counts. This makes (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) longer
        than (nn.BatchNorm, nn.Conv2d), so the former is matched first.
        """
        if not isinstance(pattern, tuple):
            return 1
        return sum(get_len(item) for item in pattern)

    def sort_key(entry):
        pattern = entry[0]
        if isinstance(pattern, tuple):
            # Negated so that an ascending sort puts longer tuples first.
            return -get_len(pattern)
        # Every non-tuple pattern gets the same positive key.
        return 1

    ordered_items = sorted(patterns_dict.items(), key=sort_key)
    return OrderedDict(ordered_items)