mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Previously, prepare_fx returned an ObservedGraphModule and convert_fx returned a QuantizedGraphModule. This was done to preserve the attributes, since torch.fx.GraphModule did not preserve them. After https://github.com/pytorch/pytorch/pull/92062, `model.meta` is preserved, so the attributes can now be stored in model.meta instead. With this change, these functions no longer need to create a new type of GraphModule and can use GraphModule directly. This is useful for the quantization flow in PyTorch 2.0: if other transformations use GraphModule as well, the quantization passes will be composable with them. Test Plan: python test/test_quantization.py TestQuantizeFx; python test/test_quantization.py TestQuantizeFxOps; python test/test_quantization.py TestQuantizeFxModels; python test/test_quantization.py TestQuantizePT2E. Imported from OSS. Differential Revision: D42979722. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94412. Approved by: https://github.com/vkuzo
162 lines
6.4 KiB
Python
162 lines
6.4 KiB
Python
from torch.fx import (
|
|
GraphModule,
|
|
Node,
|
|
map_arg
|
|
)
|
|
from torch.fx.graph import Graph
|
|
from .match_utils import (
|
|
_is_match,
|
|
MatchAllNode,
|
|
)
|
|
from .pattern_utils import (
|
|
_sorted_patterns_dict,
|
|
)
|
|
|
|
from ..backend_config import (
|
|
BackendConfig,
|
|
get_native_backend_config,
|
|
)
|
|
from ..backend_config.utils import (
|
|
get_fuser_method_mapping,
|
|
get_fusion_pattern_to_root_node_getter,
|
|
get_fusion_pattern_to_extra_inputs_getter,
|
|
)
|
|
|
|
from .custom_config import FuseCustomConfig
|
|
|
|
from .fuse_handler import (
|
|
_get_fusion_pattern_to_fuse_handler_cls,
|
|
FuseHandler,
|
|
)
|
|
|
|
from typing import Any, Callable, Dict, List, Tuple, Union
|
|
import warnings
|
|
|
|
from torch.ao.quantization.utils import Pattern, NodePattern
|
|
|
|
|
|
# Names exported by this module.
# TODO: "FuseHandler" should be made private in the future; it is listed here
# because test_public_bindings currently needs it for some reason.
__all__ = ["fuse", "FuseHandler"]
|
|
|
|
|
|
def fuse(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    """Build a new ``GraphModule`` in which matched fusion patterns are fused.

    The nodes of ``model.graph`` are visited in order and a fresh ``Graph`` is
    assembled from them:

    * a node that is the anchor of a match found by ``_find_matches`` is
      replaced by whatever the match's handler returns from ``fuse``;
    * a node that took part in no match is copied over unchanged;
    * any other matched node is dropped from the new graph.

    Args:
        model: the traced module whose graph is scanned for fusion patterns.
        is_qat: forwarded unchanged to the handler's ``fuse`` call.
        fuse_custom_config: custom fusion settings. ``None`` selects a default
            ``FuseCustomConfig()``; a dict is accepted for backward
            compatibility, triggers a deprecation warning and is converted
            with ``FuseCustomConfig.from_dict``.
        backend_config: backend description supplying the fusion patterns and
            fuser methods. ``None`` selects ``get_native_backend_config()``; a
            dict triggers a deprecation warning and is converted with
            ``BackendConfig.from_dict``.

    Returns:
        ``GraphModule(model, new_graph)`` where ``new_graph`` is the rebuilt
        graph.
    """
    # --- normalize the two config arguments -------------------------------
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, Dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.")
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    if isinstance(backend_config, Dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    named_modules = dict(model.named_modules())

    if backend_config is None:
        backend_config = get_native_backend_config()

    # --- lookup tables derived from the backend config --------------------
    handler_cls_by_pattern = _sorted_patterns_dict(
        _get_fusion_pattern_to_fuse_handler_cls(backend_config))
    fuser_method_mapping = get_fuser_method_mapping(backend_config)
    root_getter_by_pattern = get_fusion_pattern_to_root_node_getter(backend_config)
    extra_inputs_getter_by_pattern = get_fusion_pattern_to_extra_inputs_getter(backend_config)

    # node name -> (anchor node, pattern, matched nodes, handler, node -> subpattern)
    fusion_pairs = _find_matches(model, model.graph, handler_cls_by_pattern)

    # TODO: change this to inplace changes to graph, since we no longer construct
    # new GraphModule anymore
    new_graph = Graph()
    # original node name -> corresponding node in ``new_graph``
    env: Dict[Any, Any] = {}

    def _lookup(old_node):
        return env[old_node.name]

    def load_arg(a):
        # Rewrite every Node inside ``a`` to its counterpart in the new graph.
        return map_arg(a, _lookup)

    def default_root_node_getter(node_pattern):
        # Follow the last element of each nesting level until a Node is reached.
        last = node_pattern[-1]
        while not isinstance(last, Node):
            last = last[-1]
        return last

    unmatched = (None, None, None, None, None)
    for node in model.graph.nodes:
        anchor, pattern, matched_node_pattern, handler, node_to_subpattern = \
            fusion_pairs.get(node.name, unmatched)

        # subpattern recorded for this particular node, if any
        node_subpattern = (
            None if node_to_subpattern is None else node_to_subpattern.get(node, None)
        )

        if anchor is node:
            # This node anchors a match: let the handler emit the fused node.
            assert handler is not None
            root_node_getter = root_getter_by_pattern.get(pattern, default_root_node_getter)
            root_node = root_node_getter(matched_node_pattern)  # type: ignore[index]
            extra_inputs_getter = extra_inputs_getter_by_pattern.get(pattern, None)
            extra_inputs = (
                [] if extra_inputs_getter is None else extra_inputs_getter(matched_node_pattern)
            )
            # TODO: add validation that root_node is a module and has the same type
            # as the root_module in the configuration
            env[node.name] = handler.fuse(
                load_arg, named_modules, new_graph, root_node, extra_inputs, matched_node_pattern,  # type: ignore[arg-type]
                fuse_custom_config, fuser_method_mapping, is_qat)
            continue

        if anchor is None or node_subpattern is MatchAllNode:
            # Not part of any fusion: carry the node over as-is.
            env[node.name] = new_graph.node_copy(node, load_arg)
        # Otherwise the node was matched but is not the anchor, so it is
        # intentionally left out of the new graph.

    return GraphModule(model, new_graph)
|
|
|
|
def _find_matches(
    root: GraphModule,
    graph: Graph,
    pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
    """Scan ``graph`` for fusion patterns and record which nodes they cover.

    Nodes are visited in reverse order. For each node not already claimed by
    an earlier match, the patterns are tried in the iteration order of
    ``pattern_to_fuse_handler_cls``; the first one accepted by ``_is_match``
    wins and a handler is created for it by calling the mapped class with the
    node.

    Args:
        root: module whose ``named_modules()`` are handed to ``_is_match``.
        graph: the graph whose nodes are scanned.
        pattern_to_fuse_handler_cls: maps each pattern to a callable that
            builds its handler from the anchor node.

    Returns:
        A dict keyed by node name. Each value is the tuple
        ``(anchor_node, pattern, matched_node_pattern, handler,
        node_to_subpattern)`` where ``anchor_node`` is the node at which the
        match started, ``pattern`` is the full matched pattern,
        ``matched_node_pattern`` is the list the node was collected into while
        the match was applied, and ``node_to_subpattern`` is the single dict
        (shared by all entries) from node to the leaf subpattern it matched.
    """
    modules = dict(root.named_modules())
    # node name -> (root_node, match_value)
    match_map: Dict[
        str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]] = {}
    # a map from node to the matched subpattern
    node_to_subpattern: Dict[Node, Any] = {}

    # TODO: dedup with quantization matching function in match_utils.py
    def apply_match(pattern, node, match, matched_node_pattern, subpattern_map):
        if not isinstance(pattern, tuple):
            # Leaf pattern. The first pattern that matched a node takes
            # precedence, so nodes that already have an entry are skipped.
            if node.name in match_map:
                return
            matched_node_pattern.append(node)
            # MatchAllNode here is actually MatchAllInputNode which should not
            # be added to match_map
            if pattern is MatchAllNode:
                return
            subpattern_map[node] = pattern
            anchor, full_pattern, handler = match
            match_map[node.name] = (
                anchor, full_pattern, matched_node_pattern, handler, subpattern_map)
            return

        # Tuple pattern: the head matches ``node`` itself, the remaining
        # entries are paired positionally with ``node.args``.
        head, *rest = pattern
        collected: List[Node] = []
        apply_match(head, node, match, collected, subpattern_map)
        for subpattern, arg in zip(rest, node.args):
            apply_match(subpattern, arg, match, collected, subpattern_map)
        matched_node_pattern.append(tuple(collected))

    for node in reversed(graph.nodes):
        if node.name in match_map:
            # already claimed by a match anchored at a later node
            continue
        for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
            if not _is_match(modules, node, pattern):
                continue
            matched_node_pattern: List[Node] = []
            apply_match(
                pattern,
                node,
                (node, pattern, fuse_handler_cls(node)),
                matched_node_pattern,
                node_to_subpattern,
            )
            break

    return match_map
|