Summary: made _DEFAULT_FUSION_PATTERNS, _register_fusion_pattern, _DEFAULT_QUANTIZATION_PATTERNS, _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP, _DEFAULT_OUTPUT_OBSERVER_MAP, _register_quant_pattern, and _sorted_patterns_dict private.

Test Plan: python test/test_public_bindings.py

Differential Revision: [D41015537](https://our.internmc.facebook.com/intern/diff/D41015537)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88397
Approved by: https://github.com/jcaip

from collections import OrderedDict
from typing import Dict, Any
import copy

from torch.ao.quantization.utils import Pattern

from ..fake_quantize import FixedQParamsFakeQuantize
from ..observer import ObserverBase

__all__ = [
    "get_default_fusion_patterns",
    "get_default_quant_patterns",
    "get_default_output_activation_post_process_map",
]

# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
QuantizeHandler = Any

# Registry of patterns for conv-bn fusion
_DEFAULT_FUSION_PATTERNS = OrderedDict()

def _register_fusion_pattern(pattern):
    def insert(fn):
        _DEFAULT_FUSION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
    return copy.copy(_DEFAULT_FUSION_PATTERNS)

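# Illustrative sketch (not part of the original module): how a fusion handler might be
# registered and then looked up. `_ConvBNFusionHandlerSketch` is a hypothetical
# placeholder, not a real PyTorch handler class, and the snippet assumes `import torch`.
#
#     @_register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
#     class _ConvBNFusionHandlerSketch:
#         def __init__(self, node):
#             self.node = node
#
#     # get_default_fusion_patterns() returns a shallow copy of the registry, so callers
#     # can inspect or re-sort it without mutating the default map:
#     fusion_patterns = get_default_fusion_patterns()
#     assert fusion_patterns[(torch.nn.BatchNorm2d, torch.nn.Conv2d)] is _ConvBNFusionHandlerSketch
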
_DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()

# Mapping from pattern to the activation_post_process (observer/fake_quant) constructor
# for the output activation,
# e.g. pattern: torch.sigmoid,
#      output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP = {}
_DEFAULT_OUTPUT_OBSERVER_MAP = {}

# Register a pattern for both static quantization and QAT
def _register_quant_pattern(pattern, fixed_qparams_observer=None):
    def insert(fn):
        _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        if fixed_qparams_observer is not None:
            _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer)
            _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer
        return fn
    return insert

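# Illustrative sketch (not part of the original module): registering a pattern whose
# output has fixed quantization parameters. `_SigmoidQuantizeHandlerSketch` is a
# hypothetical placeholder, and the example assumes an observer partial such as
# `default_fixed_qparams_range_0to1_observer` is importable from ..observer.
#
#     from ..observer import default_fixed_qparams_range_0to1_observer
#
#     @_register_quant_pattern(torch.nn.Sigmoid, default_fixed_qparams_range_0to1_observer)
#     class _SigmoidQuantizeHandlerSketch:
#         def __init__(self, node):
#             self.node = node
#
#     # After this, _DEFAULT_QUANTIZATION_PATTERNS[torch.nn.Sigmoid] is the handler class,
#     # the fake-quant map holds FixedQParamsFakeQuantize.with_args(observer=...), and the
#     # observer map holds the observer partial itself.
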
# Get patterns for both static quantization and QAT
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
    return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS)

# A map from pattern to output activation post process constructor,
# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant
def get_default_output_activation_post_process_map(is_training) -> Dict[Pattern, ObserverBase]:
    if is_training:
        return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
    else:
        return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP)

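# Illustrative sketch (not part of the original module): how the map above is meant to
# be consumed, assuming torch.sigmoid was registered with a fixed qparams observer via
# _register_quant_pattern.
#
#     qat_map = get_default_output_activation_post_process_map(is_training=True)
#     ptq_map = get_default_output_activation_post_process_map(is_training=False)
#     # qat_map[torch.sigmoid] would be a FixedQParamsFakeQuantize constructor,
#     # ptq_map[torch.sigmoid] the corresponding observer partial.
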
# Example use of the register pattern function:
# @_register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvOrLinearBNReLUFusion():
#     def __init__(...):
#         ...
#

def _sorted_patterns_dict(patterns_dict: Dict[Pattern, QuantizeHandler]) -> Dict[Pattern, QuantizeHandler]:
    """
    Return a sorted version of the patterns dictionary such that longer patterns are matched first,
    e.g. match (F.relu, F.linear) before F.relu.
    This works for current use cases, but we may need a more clever way to sort
    things to address more complex patterns.
    """

    def get_len(pattern):
        """ Calculate the length of the pattern by counting all the entries
        in the pattern.
        This makes sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before
        (nn.BatchNorm, nn.Conv2d) so that we can match the former first.
        """
        pattern_len = 0
        if isinstance(pattern, tuple):
            for item in pattern:
                pattern_len += get_len(item)
        else:
            pattern_len += 1
        return pattern_len

    return OrderedDict(sorted(patterns_dict.items(), key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1))
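
# Illustrative sketch (not part of the original module): the order _sorted_patterns_dict
# produces. The handler values are hypothetical placeholders, and `import torch` is assumed.
#
#     patterns = OrderedDict([
#         (torch.nn.ReLU, "relu_handler"),
#         ((torch.nn.ReLU, torch.nn.Linear), "linear_relu_handler"),
#         ((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)), "conv_bn_relu_handler"),
#     ])
#     # get_len counts leaf entries, so the three keys score 1, 2 and 3. Tuple keys sort by
#     # -get_len (-2 and -3) while non-tuple keys get the constant key 1, so the sorted order is:
#     # (ReLU, (BatchNorm2d, Conv2d)), then (ReLU, Linear), then ReLU, i.e. longest patterns first.
#     _sorted_patterns_dict(patterns)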