[inductor][invoke_subgraph] Run joint graph passes for inference (#152062)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152062
Approved by: https://github.com/eellison
ghstack dependencies: #151409, #151633, #151477, #151957, #151961
Animesh Jain 2025-04-23 22:26:58 -07:00 committed by PyTorch MergeBot
parent 99b6c426a9
commit ddff3d4f6b
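The diff below threads a `skip_invoke_subgraph` flag through `_recursive_joint_graph_passes` so that `compile_fx` can run the joint graph passes for inference without descending into `invoke_subgraph` bodies; those bodies are covered later, when the partitioner is invoked recursively for each `invoke_subgraph`. The snippet that follows is a minimal standalone sketch of that recursion pattern, not the actual torch/_inductor implementation; the helper names (`run_passes_recursively`, `get_subgraph_names`, `run_passes`) are invented for illustration.

```python
from typing import Callable, Iterable


def run_passes_recursively(
    gm: object,
    get_subgraph_names: Callable[[object, bool], Iterable[str]],
    run_passes: Callable[[object], None],
    skip_invoke_subgraph: bool = False,
) -> None:
    # Visit child graph modules first; the flag is forwarded so the whole
    # invoke_subgraph subtree is excluded, not just its first level.
    for name in get_subgraph_names(gm, skip_invoke_subgraph):
        child = getattr(gm, name)
        run_passes_recursively(
            child, get_subgraph_names, run_passes, skip_invoke_subgraph
        )
    # The passes themselves run on the current graph after its children.
    run_passes(gm)
```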


@@ -454,7 +454,9 @@ def _recursive_pre_grad_passes(
     return pre_grad_passes(gm, example_inputs, add_passes, remove_passes)


-def _recursive_joint_graph_passes(gm: GraphModule) -> None:
+def _recursive_joint_graph_passes(
+    gm: GraphModule, skip_invoke_subgraph: bool = False
+) -> None:
     with dynamo_timed(
         "_recursive_joint_graph_passes",
         log_pt2_compile_event=True,
@@ -466,9 +468,9 @@ def _recursive_joint_graph_passes(gm: GraphModule) -> None:
         # AOTAutograd has access to partition_fn, which internally calls
         # `_recursive_joint_graph_passes` for the subgraph, so recursion into
         # invoke_subgraph can be skipped when skip_invoke_subgraph is set.
-        for subgraph_name in _get_subgraph_names(gm, skip_invoke_subgraph=True):
+        for subgraph_name in _get_subgraph_names(gm, skip_invoke_subgraph):
             subgraph = getattr(gm, subgraph_name)
-            _recursive_joint_graph_passes(subgraph)
+            _recursive_joint_graph_passes(subgraph, skip_invoke_subgraph)
         joint_graph_passes(gm)
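To make the effect of the flag concrete, here is a small self-contained toy (not PyTorch code; `ToyGraph` and `visit` are invented for illustration) showing which graphs the traversal above visits when `skip_invoke_subgraph=True`: children reached only through `invoke_subgraph` are left alone, everything else still gets the passes.

```python
from dataclasses import dataclass, field


@dataclass
class ToyGraph:
    name: str
    kind: str = "plain"  # "plain" or "invoke_subgraph"
    children: list["ToyGraph"] = field(default_factory=list)


def visit(g: ToyGraph, skip_invoke_subgraph: bool, visited: list[str]) -> None:
    for child in g.children:
        # Mimics _get_subgraph_names leaving out invoke_subgraph children
        # when the skip flag is set; the partitioner handles those later.
        if skip_invoke_subgraph and child.kind == "invoke_subgraph":
            continue
        visit(child, skip_invoke_subgraph, visited)
    visited.append(g.name)  # the passes run on the current graph last


root = ToyGraph(
    "root",
    children=[
        ToyGraph("cond_true_branch"),
        ToyGraph("repeated_block", kind="invoke_subgraph"),
    ],
)
order: list[str] = []
visit(root, True, order)
print(order)  # ['cond_true_branch', 'root'] -- the invoke_subgraph body is skipped
```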
@@ -2153,7 +2155,10 @@ def compile_fx(
     ) -> tuple[GraphModule, GraphModule]:
         cuda_context = get_cuda_device_context(gm)
         with cuda_context:
-            _recursive_joint_graph_passes(gm)
+            # We can skip the invoke_subgraph subgraphs here because the
+            # entire partition_fn is called recursively for invoke_subgraph
+            # in partitioning.
+            _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True)

         static_lifetime_input_indices: Optional[list[int]] = kwargs.pop(  # type: ignore[assignment]
             "static_lifetime_input_indices", None