Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76637

The previous names `default_affine_fixed_qparams_observer` and `default_symmetric_fixed_qparams_observer` were uninformative: users had to read the definitions to understand what these observers are. The new naming convention reveals the quantization range of each observer. Analogous changes were made for `default_affine_fixed_qparams_fake_quant` and `default_symmetric_fixed_qparams_fake_quant`.

Test Plan:
```
python test/test_quantization.py
```

Differential Revision: D36054169

Reviewed By: vkuzo

Pulled By: dzdang

fbshipit-source-id: 215f7786a4b7abda7327f17cc61735697ec5cca9
(cherry picked from commit 21a4e6eda4467c8adca7fd534a506a14e975f9cf)
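For reference, a sketch of the rename this PR describes. The `range_0to1` / `range_neg1to1` forms follow the new-name convention visible in the file's comments below; treat the exact identifiers as assumed rather than verified here. The affine observers fix qparams for a [0, 1] output range (e.g. sigmoid) and the symmetric ones for a [-1, 1] range (e.g. tanh), which is what the new names encode:

```
# old name                                   -> new name (encodes the range)
default_affine_fixed_qparams_observer        -> default_fixed_qparams_range_0to1_observer
default_symmetric_fixed_qparams_observer     -> default_fixed_qparams_range_neg1to1_observer
default_affine_fixed_qparams_fake_quant      -> default_fixed_qparams_range_0to1_fake_quant
default_symmetric_fixed_qparams_fake_quant   -> default_fixed_qparams_range_neg1to1_fake_quant
```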
from collections import OrderedDict
from typing import Dict, Any, Tuple, List, Optional
from torch.fx.graph import (
    Node,
)
from torch.ao.quantization.quantization_types import Pattern
from ..qconfig import QConfigAny
from ..fake_quantize import FixedQParamsFakeQuantize
# from .quantization_patterns import BinaryOpQuantizeHandler
from ..observer import ObserverBase
import copy

# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
QuantizeHandler = Any

MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                    QConfigAny]

# pattern for conv bn fusion
DEFAULT_FUSION_PATTERNS = OrderedDict()

def register_fusion_pattern(pattern):
    def insert(fn):
        DEFAULT_FUSION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
    return copy.copy(DEFAULT_FUSION_PATTERNS)
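# Illustrative usage sketch (the handler name is hypothetical, not defined in
# this file; fusion patterns are written in reverse order of execution, so
# (BatchNorm2d, Conv2d) matches a conv followed by a bn):
#
#   @register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
#   class ConvBNFusionHandler:
#       ...
#
# Note that get_default_fusion_patterns() returns a shallow copy, so callers
# may add or drop entries without mutating DEFAULT_FUSION_PATTERNS itself.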
DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()

# Mapping from pattern to activation_post_process (observer/fake_quant)
# constructor for output activation
# e.g. pattern: torch.sigmoid,
#      output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP = dict()
DEFAULT_OUTPUT_OBSERVER_MAP = dict()

# Register pattern for both static quantization and qat
def register_quant_pattern(pattern, fixed_qparams_observer=None):
    def insert(fn):
        DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        if fixed_qparams_observer is not None:
            DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer)
            DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer
        return fn
    return insert
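# Illustrative usage sketch (the handler name is hypothetical; the pattern and
# observer mirror the torch.sigmoid example in the comment above):
#
#   @register_quant_pattern(torch.sigmoid, default_fixed_qparams_range_0to1_observer)
#   class SigmoidQuantizeHandler:
#       ...
#
# This registers the handler for both static quantization and qat, and also
# records the observer (plus a FixedQParamsFakeQuantize built from it) as the
# output activation_post_process for the pattern.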
# Get patterns for both static quantization and qat
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
    return copy.copy(DEFAULT_QUANTIZATION_PATTERNS)

# a map from pattern to output activation post process constructor
# e.g. torch.sigmoid -> default_fixed_qparams_range_0to1_fake_quant
def get_default_output_activation_post_process_map(is_training) -> Dict[Pattern, ObserverBase]:
    if is_training:
        return copy.copy(DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
    else:
        return copy.copy(DEFAULT_OUTPUT_OBSERVER_MAP)
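# Illustrative usage sketch:
#
#   # qat: patterns map to FixedQParamsFakeQuantize constructors
#   qat_map = get_default_output_activation_post_process_map(is_training=True)
#   # static (eval) quantization: patterns map to the observers themselves
#   ptq_map = get_default_output_activation_post_process_map(is_training=False)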
# Example use of register pattern function:
# @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvOrLinearBNReLUFusion():
#     def __init__(...):
#         ...

def sorted_patterns_dict(patterns_dict: Dict[Pattern, QuantizeHandler]) -> Dict[Pattern, QuantizeHandler]:
    """
    Return a sorted version of the patterns dictionary such that longer patterns are matched first,
    e.g. match (F.relu, F.linear) before F.relu.
    This works for current use cases, but we may need a more clever way to sort
    things to address more complex patterns.
    """

    def get_len(pattern):
        """Calculate the length of the pattern by recursively counting all the
        entries in the pattern.
        This makes sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before
        (nn.BatchNorm, nn.Conv2d), so that we can match the former first.
        """
        count = 0
        if isinstance(pattern, tuple):
            for item in pattern:
                count += get_len(item)
        else:
            count += 1
        return count

    # Tuple patterns get a negative key (longer tuples first); non-tuple
    # patterns all get the fixed key 1, so they sort after every tuple.
    return OrderedDict(sorted(patterns_dict.items(), key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1))
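# Illustrative ordering sketch (h1, h2 are hypothetical handlers):
#
#   sorted_patterns_dict({F.relu: h1, (F.relu, F.linear): h2})
#   # -> OrderedDict([((F.relu, F.linear), h2), (F.relu, h1)])
#
# The tuple pattern sorts with key -2, the bare op with the fixed key 1,
# so every tuple pattern precedes every single-op pattern.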