Summary: This PR enables GPU-only quantization, best used with is_reference since there are not many GPU kernels for ops as of now. This PR mainly changes how qconfigs and their observer constructors operate once they are on a module's qconfig. The function add_module_to_qconfig_obs_ctr takes the observer constructors on the original qconfig and configures them so that, when invoked, the created observer will be on whatever device the module occupies. (Once observers are created, module.to(device) is already set up so that it moves any observers.) To do this, a new method and a few small changes were added to the _PartialWrapper class that our observers already use to create constructors (without changing the existing functionality). These changes work in concert with changes to the prepare flow such that when the qconfigs are propagated to the modules (in quantize.py and qconfig_utils.py) they are configured using add_module_to_qconfig_obs_ctr. Ideally this would work on other models, but the is_reference support for a lot of modules isn't there yet; those tests should be added in a future PR.

Test Plan:
python test/test_quantization.py TestQuantizeFxModels.test_static_gpu_convert_basic
python test/test_quantization.py TestQuantizeFxModels.test_switch_device_prepare_convert
python test/test_quantization.py TestQuantizeFxModels.test_prepare_serialize_switch_device_convert
python test/test_quantization.py TestQuantizeFx.test_qconfig_precedence

Reviewed By: vkuzo

Differential Revision: D29684114

fbshipit-source-id: 19fefb8e1998eaf212723e836276ccf39467f2e7
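The core idea in the summary can be sketched roughly as follows. This is a minimal illustration only, using a made-up helper name (device_aware_obs_ctr) rather than the PR's actual add_module_to_qconfig_obs_ctr / _PartialWrapper changes: the observer constructor attached to a module's qconfig is wrapped so that, when invoked later during prepare, the resulting observer lands on the same device as the module.

import torch
from torch.quantization import MinMaxObserver

def device_aware_obs_ctr(obs_ctr, module):
    # Hypothetical wrapper (illustration only, not the PR's actual implementation):
    # build the observer as usual, then move it onto whatever device the module's
    # parameters currently occupy, falling back to CPU for parameterless modules.
    def ctr(*args, **kwargs):
        devices = {p.device for p in module.parameters()}
        device = next(iter(devices)) if devices else torch.device("cpu")
        return obs_ctr(*args, **kwargs).to(device)
    return ctr

# Usage sketch: the wrapped constructor yields an observer on the module's device.
mod = torch.nn.Linear(4, 4)
if torch.cuda.is_available():
    mod = mod.cuda()
obs = device_aware_obs_ctr(MinMaxObserver.with_args(dtype=torch.quint8), mod)()
assert obs.min_val.device == next(mod.parameters()).device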
60 lines
2.1 KiB
Python
import torch
from collections import OrderedDict
from typing import Dict, Any, Tuple, List, Optional
from torch.fx.graph import (
    Node,
)
from .quantization_types import Pattern
from ..qconfig import QConfigAny
# from .quantization_patterns import BinaryOpQuantizeHandler

# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
QuantizeHandler = Any

MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                    QConfigAny]

# pattern for conv bn fusion
DEFAULT_FUSION_PATTERNS = OrderedDict()
def register_fusion_pattern(pattern):
    def insert(fn):
        DEFAULT_FUSION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
    return DEFAULT_FUSION_PATTERNS

DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()
# a map from pattern to activation_post_process (observer/fake_quant) constructor for output activation
# e.g. pattern: torch.sigmoid,
#      output_activation_post_process: default_affine_fixed_qparam_fake_quant
DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP = dict()

# Register pattern for both static quantization and qat
def register_quant_pattern(pattern, output_activation_post_process=None):
    def insert(fn):
        DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        if output_activation_post_process is not None:
            DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP[pattern] = output_activation_post_process
        return fn
    return insert

# Get patterns for both static quantization and qat
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
    return DEFAULT_QUANTIZATION_PATTERNS

# a map from pattern to output activation post process constructor
# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant
def get_default_output_activation_post_process_map() -> Dict[Pattern, torch.quantization.observer.ObserverBase]:
    return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP

# Example use of register pattern function:
# @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
#     def __init__(...):
#         ...
#
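# A hypothetical analogous use of register_quant_pattern, pairing a pattern with
# the output activation_post_process constructor mentioned in the map comment above
# (the handler class name here is made up for illustration):
# @register_quant_pattern(torch.sigmoid, default_affine_fixed_qparam_fake_quant)
# class SigmoidQuantizeHandler():
#     ...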