pytorch/torch/quantization/fx/pattern_utils.py
Charles David Hernandez 32d0c3e8ee Support for reference convert_fx working on gpu
Summary:
This PR enables GPU-only quantization. It is best used with is_reference, since
there are currently few GPU kernels for quantized ops.

This PR mainly changes how qconfigs and their observer constructors operate once
they are attached to a module's qconfig attribute. The function
add_module_to_qconfig_obs_ctr takes the observer constructors on the original
qconfig and configures them so that, when invoked, the created observer will
be on whatever device the module occupies. (Once observers are created,
module.to(device) is already set up so that it moves any observers.) To do this,
a new method and a few small changes were added to the _PartialWrapper class that
our observers already use to create constructors, without changing the
existing functionality. These changes work in concert with changes to the
prepare flow so that when the qconfigs are propagated to the modules (in
quantize.py and qconfig_utils.py), they are configured using
add_module_to_qconfig_obs_ctr.
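
To make the mechanism concrete, here is a rough sketch of the idea only; it is
not the actual implementation, and the helper names/internals below are
simplified and hypothetical:

# Illustrative sketch: wrap the observer constructors on a qconfig so that
# observers are created on the same device as the module's parameters/buffers.
import torch
from torch.quantization import QConfig

def _obs_ctr_on_device(obs_ctr, module):
    # Collect the devices of the module's parameters and buffers; if the
    # module is empty, fall back to leaving the observer where it is created.
    devices = {p.device for p in module.parameters()} | \
              {b.device for b in module.buffers()}
    device = next(iter(devices)) if len(devices) > 0 else None

    def make_observer():
        obs = obs_ctr()
        return obs.to(device) if device is not None else obs
    return make_observer

def add_module_to_qconfig_obs_ctr_sketch(qconfig, module):
    # Return a qconfig whose activation/weight constructors build observers
    # on the module's device.
    return QConfig(
        activation=_obs_ctr_on_device(qconfig.activation, module),
        weight=_obs_ctr_on_device(qconfig.weight, module),
    )

In the actual change the wrapping is done on the _PartialWrapper constructors
themselves, so existing with_args-style configuration keeps working.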

Ideally this would also work on other models, but is_reference support for
many modules isn't there yet; those tests should be added in a
future PR.

Test Plan:
python test/test_quantization.py TestQuantizeFxModels.test_static_gpu_convert_basic

python test/test_quantization.py TestQuantizeFxModels.test_switch_device_prepare_convert

python test/test_quantization.py TestQuantizeFxModels.test_prepare_serialize_switch_device_convert

python test/test_quantization.py TestQuantizeFx.test_qconfig_precedence

Reviewed By: vkuzo

Differential Revision: D29684114

fbshipit-source-id: 19fefb8e1998eaf212723e836276ccf39467f2e7
2021-07-23 10:30:38 -07:00

60 lines
2.1 KiB
Python

import torch
from collections import OrderedDict
from typing import Dict, Any, Tuple, List, Optional
from torch.fx.graph import (
    Node,
)
from .quantization_types import Pattern
from ..qconfig import QConfigAny
# from .quantization_patterns import BinaryOpQuantizeHandler
# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
QuantizeHandler = Any
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                    QConfigAny]
# pattern for conv bn fusion
DEFAULT_FUSION_PATTERNS = OrderedDict()
def register_fusion_pattern(pattern):
    def insert(fn):
        DEFAULT_FUSION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
    return DEFAULT_FUSION_PATTERNS
DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()
# a map from pattern to activation_post_process (observer/fake_quant) constructor for output activation
# e.g. pattern: torch.sigmoid,
#      output_activation_post_process: default_affine_fixed_qparam_fake_quant
DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP = dict()
# Register pattern for both static quantization and qat
def register_quant_pattern(pattern, output_activation_post_process=None):
    def insert(fn):
        DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        if output_activation_post_process is not None:
            DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP[pattern] = output_activation_post_process
        return fn
    return insert
# Get patterns for both static quantization and qat
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
    return DEFAULT_QUANTIZATION_PATTERNS
# a map from pattern to output activation post process constructor
# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant
def get_default_output_activation_post_process_map() -> Dict[Pattern, torch.quantization.observer.ObserverBase]:
    return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP

# Example use of register pattern function:
# @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
#     def __init__(...):
#         ...
#
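# An analogous (illustrative) use of register_quant_pattern; the handler class
# name below is hypothetical, and the fixed-qparams fake_quant is just one
# possible output_activation_post_process:
# @register_quant_pattern(torch.sigmoid, default_affine_fixed_qparam_fake_quant)
# class SigmoidQuantizeHandler():
#     def __init__(...):
#         ...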