mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] add subsystem to pattern matcher (#163922)
Summary: Running a toy example through `torch.compile(fullgraph=True, backend="inductor")` with default inductor config, I tried to see what passes are run in each of pre-grad, joint-graph, and post-grad phases by printing out the subsystem in `GraphTransformObserver`. However the subsystem showed up as None in a bunch of transforms that were run in each of those phases, so this PR adds some additional annotations. Note that these annotations are probably not a complete set, since other transforms may run based on changes to the config that are not covered here. Hopefully this doesn't change behavior. However, I did notice that bisecting relies on disabling various phases, which means that while before some passes would *not* be disabled (because their subsystem was `None`), now they would. Test Plan: existing tests + manual test described in summary Differential Revision: D83306676 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163922 Approved by: https://github.com/jansel
This commit is contained in:
parent
5504a06e01
commit
3059b08012
|
|
@ -31,7 +31,7 @@ from ..pattern_matcher import (
|
|||
KeywordArg,
|
||||
Match,
|
||||
MULTIPLE,
|
||||
PatternMatcherPass,
|
||||
PatternMatcherPass as PatternMatcherPassBase,
|
||||
register_graph_pattern,
|
||||
stable_topological_sort,
|
||||
)
|
||||
|
|
@ -39,6 +39,10 @@ from .decompose_mem_bound_mm import check_device
|
|||
from .replace_random import replace_random_passes
|
||||
|
||||
|
||||
PatternMatcherPass = functools.partial(
|
||||
PatternMatcherPassBase, subsystem="joint_graph_passes"
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
patterns = PatternMatcherPass()
|
||||
aten = torch.ops.aten
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ from ..pattern_matcher import (
|
|||
Match,
|
||||
MultiOutputPattern,
|
||||
MULTIPLE,
|
||||
PatternMatcherPass,
|
||||
PatternMatcherPass as PatternMatcherPassBase,
|
||||
register_graph_pattern,
|
||||
register_replacement,
|
||||
stable_topological_sort,
|
||||
|
|
@ -68,6 +68,10 @@ from .split_cat import POST_GRAD_PATTERNS
|
|||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
PatternMatcherPass = functools.partial(
|
||||
PatternMatcherPassBase, subsystem="post_grad_passes"
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
aten = torch.ops.aten
|
||||
prims = torch.ops.prims
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import types
|
||||
|
|
@ -14,7 +15,9 @@ from torch.fx.experimental.optimization import (
|
|||
matches_module_pattern,
|
||||
replace_node_module,
|
||||
)
|
||||
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
||||
from torch.fx.passes.graph_transform_observer import (
|
||||
GraphTransformObserver as GraphTransformObserverBase,
|
||||
)
|
||||
from torch.fx.passes.shape_prop import ShapeProp
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
|
||||
|
|
@ -23,7 +26,7 @@ from .. import config
|
|||
from ..fx_utils import matches_module_function_pattern
|
||||
from ..pattern_matcher import (
|
||||
init_once_fakemode,
|
||||
PatternMatcherPass,
|
||||
PatternMatcherPass as PatternMatcherPassBase,
|
||||
stable_topological_sort,
|
||||
)
|
||||
from ..utils import is_cpu_device, pass_execution_and_save
|
||||
|
|
@ -32,6 +35,13 @@ from .misc_patterns import numpy_compat_normalization
|
|||
from .split_cat import PRE_GRAD_PATTERNS
|
||||
|
||||
|
||||
PatternMatcherPass = functools.partial(
|
||||
PatternMatcherPassBase, subsystem="pre_grad_passes"
|
||||
)
|
||||
GraphTransformObserver = functools.partial(
|
||||
GraphTransformObserverBase, subsystem="pre_grad_passes"
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
efficient_conv_bn_eval_pass = PatternMatcherPass(
|
||||
|
|
@ -165,7 +175,7 @@ def lazy_init():
|
|||
|
||||
|
||||
def _get_pass_name_func(p):
|
||||
if isinstance(p, PatternMatcherPass):
|
||||
if isinstance(p, PatternMatcherPassBase):
|
||||
pass_name = p.pass_name
|
||||
pass_func = p.apply
|
||||
elif isinstance(p, types.FunctionType):
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from ..virtualized import V
|
|||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
patterns = PatternMatcherPass()
|
||||
patterns = PatternMatcherPass(subsystem="joint_graph_passes")
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ def replace_random_passes(gm: torch.fx.GraphModule):
|
|||
return 0
|
||||
|
||||
count = patterns.apply(gm)
|
||||
with GraphTransformObserver(gm, "fuse_seed_creation_pass"):
|
||||
with GraphTransformObserver(gm, "fuse_seed_creation_pass", "joint_graph_passes"):
|
||||
count += fuse_seed_creation_pass(gm.graph)
|
||||
|
||||
return count
|
||||
|
|
|
|||
|
|
@ -1905,12 +1905,14 @@ class PatternMatcherPass:
|
|||
def __init__(
|
||||
self,
|
||||
pass_name: Optional[str] = None,
|
||||
subsystem: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patterns: defaultdict[
|
||||
tuple[str, torch.fx.node.Target], list[PatternEntry]
|
||||
] = defaultdict(list)
|
||||
self.pass_name = pass_name
|
||||
self.subsystem = subsystem
|
||||
|
||||
# For a particular generated pattern repr, store all of the str representations
|
||||
# of the graph used to generate them. Because we ignore certain patterns
|
||||
|
|
@ -1950,7 +1952,7 @@ class PatternMatcherPass:
|
|||
nodes.append(graph.find_nodes(op="call_module", sort=False))
|
||||
pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
|
||||
assert isinstance(gm, torch.fx.GraphModule)
|
||||
with GraphTransformObserver(gm, pass_name):
|
||||
with GraphTransformObserver(gm, pass_name, self.subsystem):
|
||||
for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
|
||||
target = extract_target(node)
|
||||
if node.op == "call_module":
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user