[inductor] add subsystem to pattern matcher (#163922)

Summary:
Running a toy example through `torch.compile(fullgraph=True, backend="inductor")` with the default inductor config, I tried to see which passes are run in each of the pre-grad, joint-graph, and post-grad phases by printing out the subsystem in `GraphTransformObserver`. However, the subsystem showed up as `None` for a number of transforms that were run in each of those phases, so this PR adds some additional subsystem annotations.

Note that these annotations are probably not a complete set, since other transforms may run based on changes to the config that are not covered here.

Hopefully this doesn't change behavior. However, I did notice that bisecting relies on disabling various phases. This means that some passes that previously would *not* be disabled during bisection (because their subsystem was `None`) will now be disabled together with the phase they are annotated with.

Test Plan: existing tests + manual test described in summary

Differential Revision: D83306676

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163922
Approved by: https://github.com/jansel
This commit is contained in:
Avik Chaudhuri 2025-09-28 03:15:23 +00:00 committed by PyTorch MergeBot
parent 5504a06e01
commit 3059b08012
5 changed files with 28 additions and 8 deletions

View File

@ -31,7 +31,7 @@ from ..pattern_matcher import (
KeywordArg,
Match,
MULTIPLE,
PatternMatcherPass,
PatternMatcherPass as PatternMatcherPassBase,
register_graph_pattern,
stable_topological_sort,
)
@ -39,6 +39,10 @@ from .decompose_mem_bound_mm import check_device
from .replace_random import replace_random_passes
PatternMatcherPass = functools.partial(
PatternMatcherPassBase, subsystem="joint_graph_passes"
)
log = logging.getLogger(__name__)
patterns = PatternMatcherPass()
aten = torch.ops.aten

View File

@ -42,7 +42,7 @@ from ..pattern_matcher import (
Match,
MultiOutputPattern,
MULTIPLE,
PatternMatcherPass,
PatternMatcherPass as PatternMatcherPassBase,
register_graph_pattern,
register_replacement,
stable_topological_sort,
@ -68,6 +68,10 @@ from .split_cat import POST_GRAD_PATTERNS
_T = TypeVar("_T")
_P = ParamSpec("_P")
PatternMatcherPass = functools.partial(
PatternMatcherPassBase, subsystem="post_grad_passes"
)
log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import copy
import functools
import itertools
import logging
import types
@ -14,7 +15,9 @@ from torch.fx.experimental.optimization import (
matches_module_pattern,
replace_node_module,
)
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.passes.graph_transform_observer import (
GraphTransformObserver as GraphTransformObserverBase,
)
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn import functional as F
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
@ -23,7 +26,7 @@ from .. import config
from ..fx_utils import matches_module_function_pattern
from ..pattern_matcher import (
init_once_fakemode,
PatternMatcherPass,
PatternMatcherPass as PatternMatcherPassBase,
stable_topological_sort,
)
from ..utils import is_cpu_device, pass_execution_and_save
@ -32,6 +35,13 @@ from .misc_patterns import numpy_compat_normalization
from .split_cat import PRE_GRAD_PATTERNS
PatternMatcherPass = functools.partial(
PatternMatcherPassBase, subsystem="pre_grad_passes"
)
GraphTransformObserver = functools.partial(
GraphTransformObserverBase, subsystem="pre_grad_passes"
)
log = logging.getLogger(__name__)
efficient_conv_bn_eval_pass = PatternMatcherPass(
@ -165,7 +175,7 @@ def lazy_init():
def _get_pass_name_func(p):
if isinstance(p, PatternMatcherPass):
if isinstance(p, PatternMatcherPassBase):
pass_name = p.pass_name
pass_func = p.apply
elif isinstance(p, types.FunctionType):

View File

@ -17,7 +17,7 @@ from ..virtualized import V
log = logging.getLogger(__name__)
patterns = PatternMatcherPass()
patterns = PatternMatcherPass(subsystem="joint_graph_passes")
aten = torch.ops.aten
@ -27,7 +27,7 @@ def replace_random_passes(gm: torch.fx.GraphModule):
return 0
count = patterns.apply(gm)
with GraphTransformObserver(gm, "fuse_seed_creation_pass"):
with GraphTransformObserver(gm, "fuse_seed_creation_pass", "joint_graph_passes"):
count += fuse_seed_creation_pass(gm.graph)
return count

View File

@ -1905,12 +1905,14 @@ class PatternMatcherPass:
def __init__(
self,
pass_name: Optional[str] = None,
subsystem: Optional[str] = None,
) -> None:
super().__init__()
self.patterns: defaultdict[
tuple[str, torch.fx.node.Target], list[PatternEntry]
] = defaultdict(list)
self.pass_name = pass_name
self.subsystem = subsystem
# For a particular generated pattern repr, store all of the str representations
# of the graph used to generate them. Because we ignore certain patterns
@ -1950,7 +1952,7 @@ class PatternMatcherPass:
nodes.append(graph.find_nodes(op="call_module", sort=False))
pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
assert isinstance(gm, torch.fx.GraphModule)
with GraphTransformObserver(gm, pass_name):
with GraphTransformObserver(gm, pass_name, self.subsystem):
for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
target = extract_target(node)
if node.op == "call_module":