[inductor] Lift fw_compiler and bw_compiler as toplevel functions. (#161762)

This is a no-op refactor of compile_fx which lifts the logic of fw_compiler and bw_compiler into top-level functions, so that they can be reused in a different stack (e.g. precompile).

Differential Revision: [D81292968](https://our.internmc.facebook.com/intern/diff/D81292968/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161762
Approved by: https://github.com/angelayi, https://github.com/yushangdi
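
For context, here is a rough sketch of how a separate stack (e.g. precompile) might drive the lifted helpers directly once this lands. The wrapper name precompile_fw_bw, and the way the Dynamo graph, the partitioned forward/backward graphs, and their example inputs are obtained, are assumptions for illustration; only create_compiler_config_extra, get_num_model_outputs, compile_fx_forward, and compile_fx_backward come from this PR.

# Hypothetical usage sketch; not part of this PR.
from typing import Sequence

import torch
from torch._inductor import config as inductor_config
from torch._inductor.compile_fx import (
    compile_fx_backward,
    compile_fx_forward,
    create_compiler_config_extra,
    get_num_model_outputs,
)


def precompile_fw_bw(
    dynamo_gm: torch.fx.GraphModule,
    dynamo_inputs: Sequence[torch.Tensor],
    fw_gm: torch.fx.GraphModule,
    fw_inputs: Sequence[torch.Tensor],
    bw_gm: torch.fx.GraphModule,
    bw_inputs: Sequence[torch.Tensor],
):
    # Share one CompilerConfigExtra so the forward and backward compiles agree
    # on cudagraphs, graph_id, and the boxed forward device index.
    extra = create_compiler_config_extra(inductor_config)
    fw_code = compile_fx_forward(
        fw_gm,
        fw_inputs,
        num_orig_model_outputs=get_num_model_outputs(dynamo_gm),
        num_example_inputs=len(dynamo_inputs),
        compiler_config_extra=extra,
        is_inference=False,
    )
    bw_code = compile_fx_backward(
        bw_gm,
        bw_inputs,
        compiler_config_extra=extra,
    )
    return fw_code, bw_code
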
Authored by zhxchen17 on 2025-08-29 10:56:04 -07:00; committed by PyTorch MergeBot
Parent: 05eeb29976
Commit: eb78757708


@@ -154,6 +154,8 @@ else:
     from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
 
 if TYPE_CHECKING:
+    import types
+
     from torch._functorch._aot_autograd.schemas import (
         FQN,
         GraphInputName,
@@ -2120,6 +2122,248 @@ def partition_fn(
     )
 
 
+def get_num_model_outputs(model: GraphModule) -> int:
+    model_outputs_node = output_node(model)
+    model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+    return len(model_outputs)
+
+
+@dataclass(frozen=True)
+class CompilerConfigExtra:
+    cudagraphs: BoxedBool
+    graph_id: int
+    forward_device: BoxedDeviceIndex
+
+
+def create_compiler_config_extra(config: types.ModuleType) -> CompilerConfigExtra:
+    # Although cudagraphs may have been enabled via config, various
+    # conditions (which are tested within the bowels of Inductor) may
+    # force cudagraphs to be disabled. This mutable box lets us retrieve
+    # the final determination if cudagraphs actually can be used or not.
+    cudagraphs = BoxedBool(config.triton.cudagraphs)
+
+    # TODO: The modern style is to use CompileId from TracingContext to
+    # identify Inductor compilation. However, this CompileId cannot
+    # uniquely identify multiple Inductor compilations that arise from
+    # DDPOptimizer
+    graph_id = next(_graph_counter)
+
+    # See [Backward Generation Handling]
+    forward_device = BoxedDeviceIndex(None)
+
+    return CompilerConfigExtra(
+        cudagraphs=cudagraphs,
+        graph_id=graph_id,
+        forward_device=forward_device,
+    )
+
+
+def compile_fx_forward(
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
+    num_orig_model_outputs: int,
+    num_example_inputs: int,
+    compiler_config_extra: CompilerConfigExtra,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
+    is_inference: bool = False,
+) -> OutputCode:
+    """
+    Compile the forward graph of the given graph module.
+
+    Args:
+        gm: The graph module to compile.
+        example_inputs: The example inputs to use for compilation.
+        num_orig_model_outputs: The number of model outputs from the original dynamo graph.
+        num_example_inputs: The number of example inputs from the original dynamo graph.
+        compiler_config_extra: Extra configuration for the compiler.
+        inner_compile: The inner compile function to use.
+        is_inference: Whether this is an inference graph.
+    """
+    if is_inference:
+        # partition_fn won't be called
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "before_joint_graph",
+                "encoding": "string",
+            },
+            payload_fn=lambda: gm.print_readable(
+                print_output=False, include_stride=True, include_device=True
+            ),
+        )
+        _recursive_joint_graph_passes(gm)
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "after_joint_graph",
+                "encoding": "string",
+            },
+            payload_fn=lambda: gm.print_readable(
+                print_output=False, include_stride=True, include_device=True
+            ),
+        )
+
+    fixed = torch._inductor.utils.num_fw_fixed_arguments(
+        num_example_inputs, len(example_inputs)
+    )
+
+    model_outputs_node = output_node(gm)
+    if config.keep_output_stride:
+        model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+        num_model_outputs = len(model_outputs)
+
+        context = torch._guards.TracingContext.try_get()
+        # See Note [User Outputs in the inductor graph]
+        if context is not None and context.fw_metadata and not is_inference:
+            original_output_start_index = (
+                context.fw_metadata.num_mutated_inp_runtime_indices
+            )
+        else:
+            original_output_start_index = 0
+
+        assert num_orig_model_outputs <= num_model_outputs
+
+        # Note [User Outputs in the inductor graph]
+        # We makes the following assumption
+        # For inference
+        #   len(orig_model_outputs) == len(model_outputs)
+        # For training
+        #   len(orig_model_outputs) <= len(model_outputs)
+        # During training, most of the time the model_outputs starts with
+        # original module's outputs followed by saved activations.
+        # But this can be not true if the model have inplace updated tensors.
+        # AOTAutograd will make those tensors being returned before the original
+        # module's output.
+        # To make things safe, we'll use original_output_start_index field
+        # set by AOTAutograd to decide where the original module outputs start.
+        orig_output_end_idx = original_output_start_index + num_orig_model_outputs
+
+        # Sanity check: we are about to splice out the "user" outputs from the full set
+        # of "graph" outputs. Make sure we're within bounds.
+        assert orig_output_end_idx <= num_model_outputs
+
+        model_outputs_node.meta["user_visible_output_idxs"] = [
+            idx
+            for idx in range(original_output_start_index, orig_output_end_idx)
+            if isinstance(model_outputs[idx], torch.fx.Node)
+        ]
+    else:
+        model_outputs_node.meta["user_visible_output_idxs"] = []
+
+    # We also mark the invoke_subgraph outputs as user_visible to
+    # force the outputs of invoke_subgraph subgraph to follow the
+    # original strides
+    _recursive_record_user_visible_output_idxs(gm)
+
+    return inner_compile(
+        gm,
+        example_inputs,
+        static_input_idxs=get_static_input_idxs(fixed),
+        cudagraphs=compiler_config_extra.cudagraphs,
+        graph_id=compiler_config_extra.graph_id,
+        is_inference=is_inference,
+        boxed_forward_device_index=compiler_config_extra.forward_device,
+    )
+
+
+def compile_fx_backward(
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
+    compiler_config_extra: CompilerConfigExtra,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
+) -> OutputCode:
+    """
+    Compile the backward graph of the given graph module.
+
+    Args:
+        gm: The graph module to compile.
+        example_inputs: The example inputs to use for compilation.
+        compiler_config_extra: Extra configuration for the compiler.
+        inner_compile: The inner compile function to use.
+    """
+    from torch._dynamo.convert_frame import compile_lock
+
+    with compile_lock:
+        model_outputs_node = output_node(gm)
+        if config.bw_outputs_user_visible:
+            model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+            model_outputs_node.meta["user_visible_output_idxs"] = [
+                idx
+                for idx, n in enumerate(model_outputs)
+                if isinstance(n, torch.fx.Node)
+            ]
+        else:
+            model_outputs_node.meta["user_visible_output_idxs"] = []
+
+        fixed = count_tangents(gm)
+        with (
+            config.patch(get_cpp_wrapper_config())
+            if config.cpp_wrapper
+            else contextlib.nullcontext()
+        ):
+            return inner_compile(
+                gm,
+                example_inputs,
+                static_input_idxs=list(range(fixed)),
+                cudagraphs=compiler_config_extra.cudagraphs,
+                is_backward=True,
+                graph_id=compiler_config_extra.graph_id,
+                boxed_forward_device_index=compiler_config_extra.forward_device,
+            )
+
+
+def run_pre_grad_passes(
+    model_: GraphModule, example_inputs_: Sequence[InputType]
+) -> GraphModule:
+    # "before_pre_grad_graph" is used in inductor provenance
+    # tracking highlighter front-end.
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "before_pre_grad_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: model_.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        )
+        + f"\n\n # graph id: {id(model_.graph)}",
+    )
+    pre_grad_graphs_log.debug(
+        "%s",
+        lazy_format_graph_code(
+            "BEFORE PRE GRAD",
+            model_,
+            include_stride=True,
+            include_device=True,
+            colored=True,
+        ),
+    )
+    torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
+
+    if config.trace.provenance_tracking_level == 1:
+        for node in model_.graph.nodes:
+            if node.stack_trace:
+                torch._inductor.debug._inductor_pre_grad_node_stack_trace[node.name] = (
+                    node.stack_trace
+                )
+
+    model_ = _recursive_pre_grad_passes(model_, example_inputs_)
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "after_pre_grad_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: model_.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        )
+        + f"\n\n # graph id: {id(model_.graph)}",
+    )
+    return model_
+
+
 def compile_fx(
     model_: GraphModule,
     example_inputs_: Sequence[InputType],
@@ -2264,50 +2508,7 @@ def compile_fx(
     # having AOTAutograd trace it.
     # TODO: Get rid of this?
     if isinstance(model_, GraphModule):
-        # "before_pre_grad_graph" is used in inductor provenance
-        # tracking highlighter front-end.
-        trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "before_pre_grad_graph",
-                "encoding": "string",
-            },
-            payload_fn=lambda: model_.print_readable(
-                print_output=False, include_stride=True, include_device=True
-            )
-            + f"\n\n # graph id: {id(model_.graph)}",
-        )
-        pre_grad_graphs_log.debug(
-            "%s",
-            lazy_format_graph_code(
-                "BEFORE PRE GRAD",
-                model_,
-                include_stride=True,
-                include_device=True,
-                colored=True,
-            ),
-        )
-        torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
-
-        if config.trace.provenance_tracking_level == 1:
-            for node in model_.graph.nodes:
-                if node.stack_trace:
-                    torch._inductor.debug._inductor_pre_grad_node_stack_trace[
-                        node.name
-                    ] = node.stack_trace
-
-        model_ = _recursive_pre_grad_passes(model_, example_inputs_)
-        trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "after_pre_grad_graph",
-                "encoding": "string",
-            },
-            payload_fn=lambda: model_.print_readable(
-                print_output=False, include_stride=True, include_device=True
-            )
-            + f"\n\n # graph id: {id(model_.graph)}",
-        )
+        model_ = run_pre_grad_passes(model_, example_inputs_)
 
     # TODO: Move this before recursive pre-grad passes
     # NB: This short circuit never occurs for Dynamo produced graphs
@@ -2323,20 +2524,7 @@ def compile_fx(
     num_example_inputs = len(example_inputs_)
 
-    # Although cudagraphs may have been enabled via config, various
-    # conditions (which are tested within the bowels of Inductor) may
-    # force cudagraphs to be disabled. This mutable box lets us retrieve
-    # the final determination if cudagraphs actually can be used or not.
-    cudagraphs = BoxedBool(config.triton.cudagraphs)
-
-    # See [Backward Generation Handling]
-    forward_device = BoxedDeviceIndex(None)
-
-    # TODO: The modern style is to use CompileId from TracingContext to
-    # identify Inductor compilation. However, this CompileId cannot
-    # uniquely identify multiple Inductor compilations that arise from
-    # DDPOptimizer
-    graph_id = next(_graph_counter)
+    compiler_config_extra = create_compiler_config_extra(config)
 
     decompositions = (
         decompositions if decompositions is not None else select_decomp_table()
@@ -2348,105 +2536,18 @@ def compile_fx(
         is_inference: bool,
     ) -> OutputCode:
         with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
-            if is_inference:
-                # partition_fn won't be called
-                trace_structured(
-                    "artifact",
-                    metadata_fn=lambda: {
-                        "name": "before_joint_graph",
-                        "encoding": "string",
-                    },
-                    payload_fn=lambda: gm.print_readable(
-                        print_output=False, include_stride=True, include_device=True
-                    ),
-                )
-                _recursive_joint_graph_passes(gm)
-                trace_structured(
-                    "artifact",
-                    metadata_fn=lambda: {
-                        "name": "after_joint_graph",
-                        "encoding": "string",
-                    },
-                    payload_fn=lambda: gm.print_readable(
-                        print_output=False, include_stride=True, include_device=True
-                    ),
-                )
-
-            fixed = torch._inductor.utils.num_fw_fixed_arguments(
-                num_example_inputs, len(example_inputs)
-            )
-
-            model_outputs_node = output_node(gm)
-            if config.keep_output_stride:
-                model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
-                num_model_outputs = len(model_outputs)
-
-                context = torch._guards.TracingContext.try_get()
-                # See Note [User Outputs in the inductor graph]
-                if context is not None and context.fw_metadata and not is_inference:
-                    original_output_start_index = (
-                        context.fw_metadata.num_mutated_inp_runtime_indices
-                    )
-                else:
-                    original_output_start_index = 0
-
-                if isinstance(model_, GraphModule):
-                    *_, orig_model_outputs_node = model_.graph.nodes
-                    assert orig_model_outputs_node.op == "output"
-                    orig_model_outputs, _ = pytree.tree_flatten(
-                        orig_model_outputs_node.args
-                    )
-                    num_orig_model_outputs = len(orig_model_outputs)
-                else:
-                    num_orig_model_outputs = num_model_outputs
-
-                assert num_orig_model_outputs <= num_model_outputs
-
-                # Note [User Outputs in the inductor graph]
-                # We makes the following assumption
-                # For inference
-                #   len(orig_model_outputs) == len(model_outputs)
-                # For training
-                #   len(orig_model_outputs) <= len(model_outputs)
-                # During training, most of the time the model_outputs starts with
-                # original module's outputs followed by saved activations.
-                # But this can be not true if the model have inplace updated tensors.
-                # AOTAutograd will make those tensors being returned before the original
-                # module's output.
-                # To make things safe, we'll use original_output_start_index field
-                # set by AOTAutograd to decide where the original module outputs start.
-                orig_output_end_idx = (
-                    original_output_start_index + num_orig_model_outputs
-                )
-
-                # Sanity check: we are about to splice out the "user" outputs from the full set
-                # of "graph" outputs. Make sure we're within bounds.
-                assert orig_output_end_idx <= num_model_outputs
-
-                model_outputs_node.meta["user_visible_output_idxs"] = [
-                    idx
-                    for idx in range(
-                        original_output_start_index, orig_output_end_idx
-                    )
-                    if isinstance(model_outputs[idx], torch.fx.Node)
-                ]
-            else:
-                model_outputs_node.meta["user_visible_output_idxs"] = []
-
-            # We also mark the invoke_subgraph outputs as user_visible to
-            # force the outputs of invoke_subgraph subgraph to follow the
-            # original strides
-            _recursive_record_user_visible_output_idxs(gm)
-
-            return inner_compile(
+            if isinstance(model_, GraphModule):
+                num_orig_model_outputs = get_num_model_outputs(model_)
+            else:
+                num_orig_model_outputs = get_num_model_outputs(gm)
+            return compile_fx_forward(
                 gm,
                 example_inputs,
-                static_input_idxs=get_static_input_idxs(fixed),
-                cudagraphs=cudagraphs,
-                graph_id=graph_id,
+                num_orig_model_outputs=num_orig_model_outputs,
+                num_example_inputs=num_example_inputs,
+                compiler_config_extra=compiler_config_extra,
+                inner_compile=inner_compile,
                 is_inference=is_inference,
-                boxed_forward_device_index=forward_device,
             )
 
     fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = (
@@ -2460,9 +2561,9 @@ def compile_fx(
            dynamo_model=model_,
            num_example_inputs=num_example_inputs,
            inner_compile=inner_compile,
-            cudagraphs=cudagraphs,
-            graph_id=graph_id,
-            forward_device=forward_device,
+            cudagraphs=compiler_config_extra.cudagraphs,
+            graph_id=compiler_config_extra.graph_id,
+            forward_device=compiler_config_extra.forward_device,
         )
     else:
         inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
@@ -2474,38 +2575,15 @@ def compile_fx(
     def bw_compiler(
         gm: GraphModule, example_inputs: Sequence[InputType]
     ) -> OutputCode:
-        from torch._dynamo.convert_frame import compile_lock
-
         with (
             dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"),
-            compile_lock,
         ):
-            model_outputs_node = output_node(gm)
-            if config.bw_outputs_user_visible:
-                model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
-                model_outputs_node.meta["user_visible_output_idxs"] = [
-                    idx
-                    for idx, n in enumerate(model_outputs)
-                    if isinstance(n, torch.fx.Node)
-                ]
-            else:
-                model_outputs_node.meta["user_visible_output_idxs"] = []
-
-            fixed = count_tangents(gm)
-            with (
-                config.patch(get_cpp_wrapper_config())
-                if config.cpp_wrapper
-                else contextlib.nullcontext()
-            ):
-                return inner_compile(
-                    gm,
-                    example_inputs,
-                    static_input_idxs=list(range(fixed)),
-                    cudagraphs=cudagraphs,
-                    is_backward=True,
-                    graph_id=graph_id,
-                    boxed_forward_device_index=forward_device,
-                )
+            return compile_fx_backward(
+                gm,
+                example_inputs,
+                compiler_config_extra=compiler_config_extra,
+                inner_compile=inner_compile,
+            )
 
     bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler)
@@ -2592,8 +2670,8 @@ def compile_fx(
            decompositions=decompositions,
            partition_fn=partition_fn,
            keep_inference_input_mutations=True,
-            cudagraphs=cudagraphs,
-            boxed_forward_device_index=forward_device,
+            cudagraphs=compiler_config_extra.cudagraphs,
+            boxed_forward_device_index=compiler_config_extra.forward_device,
            ignore_shape_env=ignore_shape_env,
        )(model_, example_inputs_)
    except ShortenTraceback as e: