[Easy] Refactor post grad application of passes (#139293)

Refactors GraphTransformObserver to hook into the bisect manager's pass application, and reworks the post-grad passes to use it.
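
For illustration, a minimal sketch of the new call pattern with this change applied (the surrounding function and the `my_pass` callable are hypothetical placeholders, not part of this PR): each post-grad pass is wrapped in a GraphTransformObserver bound to the "post_grad_passes" subsystem, so BisectionManager can skip it centrally instead of every call site going through apply_pass.

    # Sketch only; `run_post_grad_example` and `my_pass` are hypothetical names.
    import functools

    import torch
    from torch.fx.passes.graph_transform_observer import (
        GraphTransformObserver as _Observer,
    )


    def run_post_grad_example(gm: torch.fx.GraphModule) -> None:
        # Bind every observer created here to the "post_grad_passes" bisect
        # subsystem, mirroring the functools.partial in post_grad_passes().
        GraphTransformObserver = functools.partial(
            _Observer, subsystem="post_grad_passes"
        )

        def my_pass(graph: torch.fx.Graph) -> None:
            # Any callable taking a torch.fx.Graph can serve as a pass.
            graph.eliminate_dead_code()

        # apply_graph_pass hands gm.graph to the callable, and returns None
        # without running it if BisectionManager disabled "post_grad_passes".
        GraphTransformObserver(gm, "my_pass").apply_graph_pass(my_pass)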

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139293
Approved by: https://github.com/exclamaforte
ghstack dependencies: #139292
eellison 2024-10-30 15:19:24 -07:00 committed by PyTorch MergeBot
parent 5075046db2
commit f93ebb2cf4
3 changed files with 82 additions and 49 deletions

View File

@@ -1,10 +1,11 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
+import functools
 import itertools
 import logging
 import operator
 from collections import Counter, defaultdict
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set
 
 import torch
 import torch._inductor as inductor
@@ -17,7 +18,6 @@ from torch._inductor.virtualized import ops
 from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
 from torch._utils_internal import upload_graph
 from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
-from torch.fx.passes.graph_transform_observer import GraphTransformObserver
 
 from .. import config, ir, pattern_matcher
 from ..codegen.common import BackendFeature, has_backend_feature
@@ -65,19 +65,6 @@ pass_patterns = [
 ]
 
 
-def apply_pass(pass_fn: Callable[[], object], name: Optional[str] = None) -> None:
-    # TODO - we should just make this part of GraphTransformObserver
-    from torch._inductor.bisect_helper import BisectionManager
-
-    debug_info: Optional[Callable[[], str]] = None
-    if name is not None:
-        debug_info = lambda: name  # noqa: E731
-    if BisectionManager.disable_subsystem("inductor", "post_grad_passes", debug_info):
-        return
-
-    pass_fn()
-
-
 def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     """
     Passes that run on after grad.  This is called once on the forwards
@@ -85,6 +72,11 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     The IR here has been normalized and functionalized.
     """
 
+    GraphTransformObserver = functools.partial(
+        torch.fx.passes.graph_transform_observer.GraphTransformObserver,
+        subsystem="post_grad_passes",
+    )
+
     if not torch._dynamo.config.skip_fsdp_hooks:
         remove_fsdp2_unsharded_param_graph_input_usage(gm.graph)
 
@@ -93,26 +85,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         gm.graph.eliminate_dead_code()
 
     if is_inference and config.reorder_for_locality:
-        apply_pass(lambda: reorder_for_locality(gm.graph), "reorder_for_locality")
+        GraphTransformObserver(gm, "reorder_for_locality").apply_graph_pass(
+            reorder_for_locality
+        )
 
     fake_tensor_updater = FakeTensorUpdater(gm.graph)
 
     if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass:
-        with GraphTransformObserver(gm, "post_grad_custom_pre_pass"):
-            apply_pass(
-                lambda: post_grad_custom_pre_pass(gm.graph), "post_grad_custom_pre_pass"
-            )
+        GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass(
+            post_grad_custom_pre_pass
+        )
 
     if config.pattern_matcher:
         lazy_init()
         optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
-        apply_pass(
-            lambda: group_batch_fusion_passes(gm.graph, pre_grad=False),
-            "group_batch_fusion_passes",
-        )
-        apply_pass(lambda: remove_noop_ops(gm.graph), "remove_noop_ops")
+        GraphTransformObserver(gm, "group_batch_fusion_passes").apply_graph_pass(
+            functools.partial(group_batch_fusion_passes, pre_grad=False)
+        )
+        GraphTransformObserver(gm, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
         for i, patterns in enumerate(pass_patterns):
-            apply_pass(lambda: patterns.apply(gm.graph), f"pass_pattern_{i}")  # type: ignore[arg-type]
+            GraphTransformObserver(gm, f"pass_pattern_{i}").apply_graph_pass(
+                patterns.apply
+            )
         for pass_name in config.post_grad_fusion_options:
             # skip all patterns for group batch fusions
             if pass_name in POST_GRAD_FUSIONS:
@@ -121,7 +115,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
             inductor_before_change = save_inductor_dict(
                 [pattern_matcher_pass.pass_name]
             )
-            apply_pass(lambda: pattern_matcher_pass.apply(gm.graph), pass_name)  # type: ignore[arg-type]
+            GraphTransformObserver(gm, pass_name).apply_graph_pass(
+                pattern_matcher_pass.apply
+            )
             if not is_same_dict(counters["inductor"], inductor_before_change):
                 optimus_scuba_log[
                     f"{pattern_matcher_pass.pass_name}_post_grad"
@@ -133,37 +129,37 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         micro_pipeline_tp_pass(gm.graph)
 
     if config._fuse_ddp_communication:
-        apply_pass(
-            lambda: fuse_ddp_communication(
-                gm.graph,
+        GraphTransformObserver(gm, "fuse_ddp_communication").apply_graph_pass(
+            lambda graph: fuse_ddp_communication(
+                graph,
                 config._fuse_ddp_communication_passes,
                 config._fuse_ddp_bucket_size,
-            ),
-            "fuse_ddp_communication",
+            )
         )
 
     if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
-        with GraphTransformObserver(gm, "post_grad_custom_post_pass"):
-            apply_pass(
-                lambda: post_grad_custom_post_pass(gm.graph),
-                "post_grad_custom_post_pass",
-            )
+        GraphTransformObserver(gm, "post_grad_custom_post_pass").apply_graph_pass(
+            post_grad_custom_post_pass
+        )
 
-    apply_pass(lambda: stable_topological_sort(gm.graph), "stable_sort")
+    GraphTransformObserver(gm, "stable_sort").apply_graph_pass(stable_topological_sort)
 
-    apply_pass(lambda: move_constructors_to_gpu(gm.graph), "move_constructors_to_cuda")
+    GraphTransformObserver(gm, "move_constructors_to_cuda").apply_graph_pass(
+        move_constructors_to_gpu
+    )
 
     fake_tensor_updater.incremental_update()
 
     # Keep these last, since they introduces mutation. Look at
     # ./fx_passes/README.md for a discussion of mutation invariants.
-    apply_pass(lambda: reinplace_inplaceable_ops(gm.graph), "reinplace_inplaceable_ops")
-    apply_pass(
-        lambda: decompose_auto_functionalized(gm.graph), "decompose_auto_functionalized"
-    )
+    GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(
+        reinplace_inplaceable_ops
+    )
+    GraphTransformObserver(gm, "decompose_auto_functionalized").apply_graph_pass(
+        decompose_auto_functionalized
+    )
 
-    apply_pass(
-        lambda: comms.reinplace_fsdp_all_gather(gm.graph), "reinplace_fsdp_all_gather"
-    )
+    GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass(
+        comms.reinplace_fsdp_all_gather
+    )
 
     gm.recompile()

View File

@@ -1717,7 +1717,7 @@ class PatternMatcherPass:
     def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
         return self.patterns[item]
 
-    def apply(self, gm: torch.fx.GraphModule) -> int:
+    def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
         if not self.patterns:
             return 0
         if isinstance(gm, torch.fx.GraphModule):
@@ -1745,6 +1745,7 @@ class PatternMatcherPass:
         if has_call_module:
             nodes.append(graph.find_nodes(op="call_module", sort=False))
         pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
+        assert isinstance(gm, torch.fx.GraphModule)
         with GraphTransformObserver(gm, pass_name):
             for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
                 target = extract_target(node)

View File

@@ -1,10 +1,15 @@
 # mypy: allow-untyped-defs
 import os
-from typing import Optional
+from typing import Callable, Optional, TypeVar
 
+from torch.fx import Graph
 from torch.fx._compatibility import compatibility
 from torch.fx.graph_module import GraphModule
 
+T = TypeVar("T")
+
 from .graph_drawer import FxGraphDrawer
@@ -16,12 +21,20 @@ class GraphTransformObserver:
     __pass_count = 0
 
     def __init__(
-        self, gm: GraphModule, passname: str, *, log_url: Optional[str] = None
+        self,
+        gm: GraphModule,
+        passname: str,
+        subsystem: Optional[str] = None,
+        log_url: Optional[str] = None,
     ):
         """
         log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified
         """
+        self.gm = gm
+        self.passname = passname
+        self.subsystem = subsystem
 
         # If log_url is None, we don't log anything
         if log_url is None:
             from torch._inductor.config import trace
@@ -32,8 +45,6 @@ class GraphTransformObserver:
         if self.log_url is None:
             return
         GraphTransformObserver.__pass_count += 1
-        self.gm = gm
-        self.passname = passname
 
         self.input_dot_graph = FxGraphDrawer(
             self.gm,
@@ -46,6 +57,31 @@ class GraphTransformObserver:
     def get_current_pass_count(cls):
         return cls.__pass_count
 
+    def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> Optional[T]:
+        with self:
+            if not self._check_disable_pass():
+                return pass_fn(self.gm)
+
+        return None
+
+    def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> Optional[T]:
+        with self:
+            if not self._check_disable_pass():
+                return pass_fn(self.gm.graph)
+
+        return None
+
+    def _check_disable_pass(self):
+        if self.subsystem is None:
+            return False
+
+        debug_info = lambda: self.passname  # noqa: E731
+        from torch._inductor.bisect_helper import BisectionManager
+
+        return BisectionManager.disable_subsystem(
+            "inductor", self.subsystem, debug_info
+        )
+
     def __enter__(self):
         if self.log_url is None or self.gm is None:
             return self