Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70009
Currently we rely on `module.training` to decide whether we do a QAT fusion or a PTQ fusion. This is not ideal since the training flag has nothing to do with quantization, so this PR introduces an extra flag `is_qat` to control it.
Note: for now we still have the constraint that when `is_qat` is True the modules must be in training mode; we can relax this constraint later.
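
To illustrate the direction, here is a minimal sketch of a fuser method that branches on the new `is_qat` argument instead of inspecting `module.training`. The function name is hypothetical and the class choices (`torch.nn.intrinsic.ConvBn2d` for QAT, `fuse_conv_bn_eval` for PTQ) only mirror the existing conv+bn fusion; this is not the exact implementation in `fuser_method_mappings.py`:

```python
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.nn.utils.fusion import fuse_conv_bn_eval

def fuse_conv_bn_sketch(is_qat: bool, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Module:
    # Illustrative sketch only; the real fuser methods live in
    # torch/ao/quantization/fuser_method_mappings.py.
    if is_qat:
        # QAT fusion keeps the BatchNorm around so its statistics can still
        # be updated during training; the intrinsic module wraps both.
        return nni.ConvBn2d(conv, bn)
    # PTQ fusion folds the BatchNorm parameters into the Conv weights
    # (this helper requires both modules to be in eval mode).
    return fuse_conv_bn_eval(conv, bn)
```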
Test Plan:
```
python test/test_quantization.py TestFuseFx
python test/test_quantization.py TestFusion
```
Imported from OSS
**Static Docs Preview: classyvision**
|[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D33178977/V36/classyvision/)|
|**Modified Pages**|
Reviewed By: mruberry
Differential Revision: D33178977
fbshipit-source-id: 0c1499c45526971140d9ad58e2994d1edf5ad770
(cherry picked from commit 2d51f9fb28)
122 lines · 5.8 KiB · Python
import torch
from torch.fx.graph import Node
from .pattern_utils import (
    register_fusion_pattern,
)
from ..utils import _parent_name
from .quantization_types import QuantizerCls, NodePattern, Pattern
from ..fuser_method_mappings import get_fuser_method_new
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Union, List
from .match_utils import MatchAllNode

# ----------------------------
# Fusion Pattern Registrations
# ----------------------------

# Base Pattern Handler
class FuseHandler(ABC):
    """ Base handler class for the fusion patterns
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        pass

    @abstractmethod
    def fuse(self,
             quantizer: QuantizerCls,
             load_arg: Callable,
             root_node: Node,
             matched_node_pattern: NodePattern,
             fuse_custom_config_dict: Dict[str, Any],
             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
             is_qat: bool) -> Node:
        pass

@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm3d))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.ConvTranspose1d))
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d))
@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.ConvTranspose3d))
class DefaultFuseHandler(FuseHandler):
    def __init__(
            self,
            quantizer: QuantizerCls,
            node: Node):
        super().__init__(quantizer, node)

    def fuse(self,
             quantizer: QuantizerCls,
             load_arg: Callable,
             root_node: Node,
             matched_node_pattern: NodePattern,
             fuse_custom_config_dict: Dict[str, Any],
             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
             is_qat: bool) -> Node:
        additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
        assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
        root_module = quantizer.modules[root_node.target]
        assert len(additional_fuser_method_mapping) == 0, \
            "Fusion implementation is undergoing changes, " \
            "additional_fuser_method_mapping is not supported currently."

        def get_modules(pattern, modules):
            """ Given a node pattern, extract the corresponding modules
            e.g. input: (relu_node, (bn_node, conv_node))
                 output: (relu_module, (bn_module, conv_module))
            """
            if isinstance(pattern, (tuple, list)):
                n, *args = pattern
                get_modules(n, modules)
                arg_modules: List[torch.nn.Module] = []
                for a in args:
                    get_modules(a, arg_modules)
                arg_modules = tuple(arg_modules) if len(arg_modules) > 1 else arg_modules[0]  # type: ignore[assignment]
                modules.append(arg_modules)
            else:
                n = pattern
                if n.op == "call_module":
                    modules.append(quantizer.modules[n.target])
                elif n.op == "call_function" and n.target == torch.nn.functional.relu:
                    relu = torch.nn.ReLU()
                    relu.training = root_module.training
                    modules.append(relu)
                else:
                    modules.append(MatchAllNode)
            return tuple(modules)

        # since relu can be used multiple times, we'll need to create a relu module for each match
        matched_modules = get_modules(matched_node_pattern, [])

        def get_matched_types(m):
            if isinstance(m, tuple):
                return tuple(map(get_matched_types, m))
            return type(m)

        matched_module_types = get_matched_types(matched_modules)
        module_parent_name, module_name = _parent_name(root_node.target)
        fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
        # TODO: change the signature for fuser_method to take matched module patterns
        # as input
        fused_module = fuser_method(is_qat, *matched_modules)
        # TODO: maybe add a pass to cleanup bn modules?
        setattr(quantizer.modules[module_parent_name], module_name, fused_module)
        return quantizer.fused_graph.node_copy(root_node, load_arg)
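
For context, here is a hedged end-to-end sketch of how this handler gets exercised: the PTQ fusion entry point traces the model, matches one of the registered patterns above, and swaps the fused module in at the root module's name via `DefaultFuseHandler.fuse`. The import path assumes a build where `fuse_fx` is exposed from `torch.ao.quantization.quantize_fx`; the QAT (`is_qat=True`) path is driven by `prepare_qat_fx` instead.

```python
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import fuse_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = M().eval()  # fuse_fx expects an eval-mode model (PTQ fusion)
fused = fuse_fx(m)
# The (ReLU, (BatchNorm2d, Conv2d)) pattern is matched, the fuser method folds
# the BatchNorm into the Conv, and the fused module replaces `conv`.
print(fused.conv)
```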