diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 3fd9332ea1e..e2e6f1fbdc5 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -1,10 +1,11 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
+import functools
 import itertools
 import logging
 import operator
 from collections import Counter, defaultdict
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set

 import torch
 import torch._inductor as inductor
@@ -17,7 +18,6 @@ from torch._inductor.virtualized import ops
 from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
 from torch._utils_internal import upload_graph
 from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
-from torch.fx.passes.graph_transform_observer import GraphTransformObserver

 from .. import config, ir, pattern_matcher
 from ..codegen.common import BackendFeature, has_backend_feature
@@ -65,19 +65,6 @@ pass_patterns = [
 ]


-def apply_pass(pass_fn: Callable[[], object], name: Optional[str] = None) -> None:
-    # TODO - we should just make this part of GraphTransformObserver
-    from torch._inductor.bisect_helper import BisectionManager
-
-    debug_info: Optional[Callable[[], str]] = None
-    if name is not None:
-        debug_info = lambda: name  # noqa: E731
-
-    if BisectionManager.disable_subsystem("inductor", "post_grad_passes", debug_info):
-        return
-    pass_fn()
-
-
 def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     """
     Passes that run on after grad.  This is called once on the forwards
@@ -85,6 +72,11 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):

     The IR here has been normalized and functionalized.
     """
+    GraphTransformObserver = functools.partial(
+        torch.fx.passes.graph_transform_observer.GraphTransformObserver,
+        subsystem="post_grad_passes",
+    )
+
     if not torch._dynamo.config.skip_fsdp_hooks:
         remove_fsdp2_unsharded_param_graph_input_usage(gm.graph)

@@ -93,26 +85,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         gm.graph.eliminate_dead_code()

     if is_inference and config.reorder_for_locality:
-        apply_pass(lambda: reorder_for_locality(gm.graph), "reorder_for_locality")
+        GraphTransformObserver(gm, "reorder_for_locality").apply_graph_pass(
+            reorder_for_locality
+        )

     fake_tensor_updater = FakeTensorUpdater(gm.graph)

     if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass:
-        with GraphTransformObserver(gm, "post_grad_custom_pre_pass"):
-            apply_pass(
-                lambda: post_grad_custom_pre_pass(gm.graph), "post_grad_custom_pre_pass"
-            )
+        GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass(
+            post_grad_custom_pre_pass
+        )

     if config.pattern_matcher:
         lazy_init()
         optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
-        apply_pass(
-            lambda: group_batch_fusion_passes(gm.graph, pre_grad=False),
-            "group_batch_fusion_passes",
+        GraphTransformObserver(gm, "group_batch_fusion_passes").apply_graph_pass(
+            functools.partial(group_batch_fusion_passes, pre_grad=False)
         )
-        apply_pass(lambda: remove_noop_ops(gm.graph), "remove_noop_ops")
+        GraphTransformObserver(gm, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
         for i, patterns in enumerate(pass_patterns):
-            apply_pass(lambda: patterns.apply(gm.graph), f"pass_pattern_{i}")  # type: ignore[arg-type]
+            GraphTransformObserver(gm, f"pass_pattern_{i}").apply_graph_pass(
+                patterns.apply
+            )
         for pass_name in config.post_grad_fusion_options:
             # skip all patterns for group batch fusions
             if pass_name in POST_GRAD_FUSIONS:
@@ -121,7 +115,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
             inductor_before_change = save_inductor_dict(
                 [pattern_matcher_pass.pass_name]
             )
-            apply_pass(lambda: pattern_matcher_pass.apply(gm.graph), pass_name)  # type: ignore[arg-type]
+            GraphTransformObserver(gm, pass_name).apply_graph_pass(
+                pattern_matcher_pass.apply
+            )
             if not is_same_dict(counters["inductor"], inductor_before_change):
                 optimus_scuba_log[
                     f"{pattern_matcher_pass.pass_name}_post_grad"
@@ -133,37 +129,37 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         micro_pipeline_tp_pass(gm.graph)

     if config._fuse_ddp_communication:
-        apply_pass(
-            lambda: fuse_ddp_communication(
-                gm.graph,
+        GraphTransformObserver(gm, "fuse_ddp_communication").apply_graph_pass(
+            lambda graph: fuse_ddp_communication(
+                graph,
                 config._fuse_ddp_communication_passes,
                 config._fuse_ddp_bucket_size,
-            ),
-            "fuse_ddp_communication",
+            )
         )

     if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
-        with GraphTransformObserver(gm, "post_grad_custom_post_pass"):
-            apply_pass(
-                lambda: post_grad_custom_post_pass(gm.graph),
-                "post_grad_custom_post_pass",
-            )
+        GraphTransformObserver(gm, "post_grad_custom_post_pass").apply_graph_pass(
+            post_grad_custom_post_pass
+        )

-    apply_pass(lambda: stable_topological_sort(gm.graph), "stable_sort")
+    GraphTransformObserver(gm, "stable_sort").apply_graph_pass(stable_topological_sort)

-    apply_pass(lambda: move_constructors_to_gpu(gm.graph), "move_constructors_to_cuda")
+    GraphTransformObserver(gm, "move_constructors_to_cuda").apply_graph_pass(
+        move_constructors_to_gpu
+    )

     fake_tensor_updater.incremental_update()

     # Keep these last, since they introduces mutation. Look at
     # ./fx_passes/README.md for a discussion of mutation invariants.
-    apply_pass(lambda: reinplace_inplaceable_ops(gm.graph), "reinplace_inplaceable_ops")
-    apply_pass(
-        lambda: decompose_auto_functionalized(gm.graph), "decompose_auto_functionalized"
+    GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(
+        reinplace_inplaceable_ops
     )
-
-    apply_pass(
-        lambda: comms.reinplace_fsdp_all_gather(gm.graph), "reinplace_fsdp_all_gather"
+    GraphTransformObserver(gm, "decompose_auto_functionalized").apply_graph_pass(
+        decompose_auto_functionalized
+    )
+    GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass(
+        comms.reinplace_fsdp_all_gather
     )

     gm.recompile()
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 7c6e02dd178..436ac963c2e 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -1717,7 +1717,7 @@ class PatternMatcherPass:
     def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
         return self.patterns[item]

-    def apply(self, gm: torch.fx.GraphModule) -> int:
+    def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
         if not self.patterns:
             return 0
         if isinstance(gm, torch.fx.GraphModule):
@@ -1745,6 +1745,7 @@ class PatternMatcherPass:
         if has_call_module:
             nodes.append(graph.find_nodes(op="call_module", sort=False))
         pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
+        assert isinstance(gm, torch.fx.GraphModule)
         with GraphTransformObserver(gm, pass_name):
             for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
                 target = extract_target(node)
diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py
index a96e25e799a..7c2b0b1940a 100644
--- a/torch/fx/passes/graph_transform_observer.py
+++ b/torch/fx/passes/graph_transform_observer.py
@@ -1,10 +1,15 @@
 # mypy: allow-untyped-defs
 import os
-from typing import Optional
+from typing import Callable, Optional, TypeVar

+from torch.fx import Graph
 from torch.fx._compatibility import compatibility
 from torch.fx.graph_module import GraphModule

+
+T = TypeVar("T")
+
+
 from .graph_drawer import FxGraphDrawer


@@ -16,12 +21,20 @@ class GraphTransformObserver:
     __pass_count = 0

     def __init__(
-        self, gm: GraphModule, passname: str, *, log_url: Optional[str] = None
+        self,
+        gm: GraphModule,
+        passname: str,
+        subsystem: Optional[str] = None,
+        log_url: Optional[str] = None,
     ):
         """
         log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified
         """
+        self.gm = gm
+        self.passname = passname
+        self.subsystem = subsystem
+
         # If log_url is None, we don't log anything
         if log_url is None:
             from torch._inductor.config import trace
@@ -32,8 +45,6 @@ class GraphTransformObserver:
         if self.log_url is None:
             return
         GraphTransformObserver.__pass_count += 1
-        self.gm = gm
-        self.passname = passname

         self.input_dot_graph = FxGraphDrawer(
             self.gm,
@@ -46,6 +57,31 @@ class GraphTransformObserver:
     def get_current_pass_count(cls):
         return cls.__pass_count

+    def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> Optional[T]:
+        with self:
+            if not self._check_disable_pass():
+                return pass_fn(self.gm)
+
+        return None
+
+    def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> Optional[T]:
+        with self:
+            if not self._check_disable_pass():
+                return pass_fn(self.gm.graph)
+
+        return None
+
+    def _check_disable_pass(self):
+        if self.subsystem is None:
+            return False
+
+        debug_info = lambda: self.passname  # noqa: E731
+        from torch._inductor.bisect_helper import BisectionManager
+
+        return BisectionManager.disable_subsystem(
+            "inductor", self.subsystem, debug_info
+        )
+
     def __enter__(self):
         if self.log_url is None or self.gm is None:
             return self
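
Below is a minimal standalone sketch of how the new observer API is meant to be driven (illustrative only, not part of the patch: the Toy module and strip_dead_code helper are hypothetical). apply_graph_pass runs a Graph-level callable under the observer, and _check_disable_pass only consults BisectionManager when a subsystem is supplied, which post_grad_passes now arranges by binding subsystem="post_grad_passes" through functools.partial.

# Illustrative sketch (not part of the patch above): exercising the new
# GraphTransformObserver.apply_graph_pass API on a hypothetical toy module.
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.graph_transform_observer import GraphTransformObserver


class Toy(torch.nn.Module):
    def forward(self, x):
        unused = x + 1  # dead node for the pass below to eliminate
        return x * 2


def strip_dead_code(graph: torch.fx.Graph) -> None:
    # A Graph -> None pass; apply_graph_pass calls it with the observer's gm.graph.
    graph.eliminate_dead_code()


gm = symbolic_trace(Toy())

# subsystem defaults to None, so _check_disable_pass returns False and the
# pass always runs; passing subsystem="post_grad_passes" instead routes the
# decision through BisectionManager.disable_subsystem, as post_grad.py now does.
GraphTransformObserver(gm, "strip_dead_code").apply_graph_pass(strip_dead_code)

gm.recompile()
print(gm.code)  # the x + 1 node is gone

apply_gm_pass is the GraphModule-level counterpart: identical flow, but the callable receives self.gm rather than self.gm.graph.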