outline various stages from aot stage2 compile (#164808)

Splits the training and inference paths for aot stage2 compile.
1. Split `aot_stage2_autograd` into `_aot_stage2a_partition`, `_aot_stage2b_fw_compile`, `_aot_stage2b_bw_compile`, and the rest.
2. Split `aot_stage2_inference` into `_aot_stage2b_inference_compile` and the rest.
I'm leaving these as functions with underscore-prefixed names since their I/O interfaces and the exact boundaries of these splits are still somewhat up in the air.
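
For orientation, here is a condensed sketch of the resulting call structure, reconstructed from the diff below. Logging payloads, caching, and the remaining runtime-wrapper plumbing are elided with `...`, and signatures are abbreviated for readability; this is not the verbatim code.

```python
def aot_stage2_autograd(aot_state, aot_graph_capture):
    fx_g = aot_graph_capture.graph_module
    maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta
    fw_metadata = aot_state.fw_metadata
    aot_config = aot_state.aot_config

    joint_graph_str = _log_joint_graph(fx_g, aot_config)
    _apply_tensorify_python_scalars(fx_g)

    # Stage 2a: partition the joint graph into forward and backward graphs.
    (
        fw_module,
        bw_module,
        num_fw_outs_saved_for_bw,
        num_symints_saved_for_bw,
        _indices_of_inps_to_detach,
        adjusted_flat_args,
    ) = _aot_stage2a_partition(
        fx_g, aot_graph_capture.updated_flat_args, maybe_subclass_meta,
        fw_metadata, aot_config,
    )
    fw_module_str, bw_module_str = _log_fw_bw_graphs(
        fw_module, bw_module, maybe_subclass_meta, fw_metadata, aot_config
    )

    # Stage 2b: compile the forward and backward graphs separately.
    fwd_output_strides, compiled_fw_func = _aot_stage2b_fw_compile(
        fw_module, adjusted_flat_args, maybe_subclass_meta, fw_metadata,
        num_fw_outs_saved_for_bw, aot_config,
    )
    placeholder_list, compiled_bw_func = _aot_stage2b_bw_compile(
        bw_module, maybe_subclass_meta, fw_metadata, fwd_output_strides,
        num_symints_saved_for_bw, aot_config,
    )

    # The "rest": cache entry construction, AOTDispatchAutograd.post_compile,
    # and the remaining runtime wrappers, unchanged by this split.
    ...


def aot_stage2_inference(aot_state, aot_graph_capture):
    # fw_module, updated_flat_args, maybe_subclass_meta, fw_metadata, and
    # aot_config are pulled out of aot_state / aot_graph_capture as before.
    ...
    aot_forward_graph_str = _log_inference_graph(fw_module, aot_config)
    _apply_tensorify_python_scalars(fw_module)

    # Stage 2b: compile the single inference graph.
    compiled_fw = _aot_stage2b_inference_compile(
        fw_module, updated_flat_args, maybe_subclass_meta, fw_metadata, aot_config
    )

    # The "rest": make_runtime_safe, RuntimeWrapper, and the remaining
    # post_compile wrappers, unchanged by this split.
    ...
```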

Differential Revision: D84028203

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164808
Approved by: https://github.com/SherlockNoMad
Avik Chaudhuri 2025-10-10 17:04:36 +00:00 committed by PyTorch MergeBot
parent d41aa187ec
commit 3ed90f5a09


@@ -84,6 +84,7 @@ from .schemas import (
FlatFn,
FxValue,
MutationType,
SubclassMeta,
ViewAndMutationMeta,
)
from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta
@@ -235,6 +236,28 @@ def sanitize_aot_config(input: AOTConfig) -> AOTConfig:
)
def _get_inner_meta(
maybe_subclass_meta: Optional[SubclassMeta],
fw_metadata: ViewAndMutationMeta,
) -> ViewAndMutationMeta:
"""
Util to get view and mutation metadata.
"""
return (
fw_metadata if maybe_subclass_meta is None else maybe_subclass_meta.fw_metadata
)
def _apply_tensorify_python_scalars(module: torch.fx.GraphModule) -> None:
"""
Util to apply tensorify_python_scalars.
"""
# TODO(anijain2305) - Add tensorify_python_scalars to the HOP graph passes.
fake_mode = detect_fake_mode()
if fake_mode is not None and fake_mode.shape_env is not None:
tensorify_python_scalars(module, fake_mode.shape_env, fake_mode)
def aot_stage2_compile(
aot_state: AOTState,
aot_graph_capture: AOTGraphCapture,
@@ -259,6 +282,125 @@ def aot_stage2_compile(
return aot_stage2_inference(aot_state, aot_graph_capture)
def _log_inference_graph(
fw_module: torch.fx.GraphModule,
aot_config: AOTConfig,
) -> Optional[str]:
"""
Log the inference graph to the structured logger.
Return a str representation of the graph.
"""
if aot_config.enable_log:
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "torch._functorch.config",
"encoding": "string",
},
payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
)
# Save the forward_graph_str right after aot_dispatch_base_graph,
# to save in the cache
aot_forward_graph_str = None
if aot_config.cache_info is not None:
aot_forward_graph_str = fw_module.print_readable(
print_output=False,
include_stride=True,
include_device=True,
fast_sympy_print=True,
expanded_def=True,
)
return aot_forward_graph_str
def _aot_stage2b_inference_compile(
fw_module: torch.fx.GraphModule,
updated_flat_args: list[Any],
maybe_subclass_meta: Optional[SubclassMeta],
fw_metadata: ViewAndMutationMeta,
aot_config,
) -> Callable:
"""
Compile the inference graph. Returns the compiled inference function.
Mostly this is very similar to _aot_stage2b_fw_compile.
Before compiling, we run pre_compile for the following wrappers:
- FakifiedOutWrapper
- FunctionalizedRngRuntimeWrapper
After compiling, we run post_compile for the following wrappers:
- EffectTokensWrapper
- AOTDispatchSubclassWrapper
- FunctionalizedRngRuntimeWrapper
- FakifiedOutWrapper
"""
disable_amp = torch._C._is_any_autocast_enabled()
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context(), track_graph_compiling(aot_config, "inference"):
fakified_out_wrapper = FakifiedOutWrapper()
fakified_out_wrapper.pre_compile(
fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
)
functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper()
functionalized_rng_wrapper.pre_compile(
fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
)
if tracing_context := torch._guards.TracingContext.try_get():
tracing_context.fw_metadata = _get_inner_meta(
maybe_subclass_meta,
fw_metadata,
)
with TracingContext.report_output_strides() as fwd_output_strides:
compiled_fw = aot_config.inference_compiler(fw_module, updated_flat_args)
# However, RuntimeWrapper does not expect the rng offsets in the
# output. So, we have to create another wrapper and take out the offset. As
# a result, we have to account for not boxed_call compilers as well.
if not getattr(compiled_fw, "_boxed_call", False):
compiled_fw = make_boxed_func(compiled_fw)
if fakified_out_wrapper.needs_post_compile:
fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
compiled_fw = EffectTokensWrapper().post_compile(
compiled_fw,
aot_config,
runtime_metadata=fw_metadata,
)
# Why do we need to pass in num_fw_outs_saved_for_bw?
# See Note: [Partitioner handling for Subclasses, Part 2]
compiled_fw = AOTDispatchSubclassWrapper(
trace_joint=False,
# TODO: once we use pre_compile this will be flat_fn at the top of this function
fw_only=None,
maybe_subclass_meta=maybe_subclass_meta,
num_fw_outs_saved_for_bw=None,
).post_compile(
compiled_fw,
aot_config, # not used
runtime_metadata=fw_metadata,
)
# Create a wrapper to set up the rng functionalize and fakified out bits
compiled_fw = functionalized_rng_wrapper.post_compile(
compiled_fw, aot_config, runtime_metadata=fw_metadata
)
compiled_fw = fakified_out_wrapper.post_compile(
compiled_fw,
aot_config,
runtime_metadata=fw_metadata,
)
return compiled_fw
def aot_stage2_inference(
aot_state: AOTState,
aot_graph_capture: AOTGraphCapture,
@@ -275,100 +417,17 @@ def aot_stage2_inference(
maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta
CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="inference")
aot_forward_graph_str = _log_inference_graph(fw_module, aot_config)
# Save the forward_graph_str right after aot_dispatch_base_graph,
# to save in the cache
aot_forward_graph_str = None
if aot_config.cache_info is not None:
aot_forward_graph_str = fw_module.print_readable(
print_output=False,
include_stride=True,
include_device=True,
fast_sympy_print=True,
expanded_def=True,
)
fakified_out_wrapper = FakifiedOutWrapper()
fakified_out_wrapper.pre_compile(
fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
)
functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper()
functionalized_rng_wrapper.pre_compile(
fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
)
assert isinstance(fw_module, GraphModule)
_apply_tensorify_python_scalars(fw_module)
if aot_config.enable_log:
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "torch._functorch.config",
"encoding": "string",
},
payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
)
disable_amp = torch._C._is_any_autocast_enabled()
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context(), track_graph_compiling(aot_config, "inference"):
compiler = (
aot_config.inference_compiler
if aot_config.inference_compiler is not None
else aot_config.fw_compiler
)
if tracing_context := torch._guards.TracingContext.try_get():
tracing_context.fw_metadata = (
fw_metadata
if maybe_subclass_meta is None
else maybe_subclass_meta.fw_metadata
)
with TracingContext.report_output_strides() as fwd_output_strides:
fake_mode = detect_fake_mode()
if fake_mode is not None and fake_mode.shape_env is not None:
tensorify_python_scalars(fw_module, fake_mode.shape_env, fake_mode)
compiled_fw = compiler(fw_module, updated_flat_args)
# However, RuntimeWrapper does not expect the rng offsets in the
# output. So, we have to create another wrapper and take out the offset. As
# a result, we have to account for not boxed_call compilers as well.
if not getattr(compiled_fw, "_boxed_call", False):
compiled_fw = make_boxed_func(compiled_fw)
if fakified_out_wrapper.needs_post_compile:
fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
compiled_fw = EffectTokensWrapper().post_compile(
compiled_fw,
compiled_fw = _aot_stage2b_inference_compile(
fw_module,
updated_flat_args, # type: ignore[arg-type]
maybe_subclass_meta,
fw_metadata,
aot_config,
runtime_metadata=fw_metadata,
)
# Why do we need to pass in num_fw_outs_saved_for_bw?
# See Note: [Partitioner handling for Subclasses, Part 2]
compiled_fw = AOTDispatchSubclassWrapper(
trace_joint=False,
# TODO: once we use pre_compile this will be flat_fn at the top of this function
fw_only=None,
maybe_subclass_meta=maybe_subclass_meta,
num_fw_outs_saved_for_bw=None,
).post_compile(
compiled_fw,
aot_config, # not used
runtime_metadata=fw_metadata,
)
# Create a wrapper to set up the rng functionalize and fakified out bits
compiled_fw = functionalized_rng_wrapper.post_compile(
compiled_fw, aot_config, runtime_metadata=fw_metadata
)
compiled_fw = fakified_out_wrapper.post_compile(
compiled_fw,
aot_config,
runtime_metadata=fw_metadata,
)
make_runtime_safe(fw_metadata, maybe_subclass_meta)
@@ -409,6 +468,7 @@ def aot_stage2_inference(
)
compiled_fw = SerializableCompiledFunction(compiled_fw, lambda: entry)
disable_amp = torch._C._is_any_autocast_enabled()
compiled_fn = RuntimeWrapper(
indices_of_inps_to_detach=[],
trace_joint=False,
@@ -1332,27 +1392,14 @@ def maybe_inline_graph_saved_tensors_hooks(
bw_module.recompile()
def aot_stage2_autograd(
aot_state: AOTState,
aot_graph_capture: AOTGraphCapture,
) -> DispatchReturn:
def _log_joint_graph(
fx_g: torch.fx.GraphModule,
aot_config: AOTConfig,
) -> Optional[str]:
"""
Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers,
and returns a wrapped torch.autograd.Function with a forward and backward.
Log the joint graph to the structured logger.
Return a str representation of the graph.
"""
wrappers = aot_graph_capture.wrappers
fx_g = aot_graph_capture.graph_module
flat_args = aot_state.flat_args
joint_inputs = aot_graph_capture.updated_flat_args
maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta
aot_config = aot_state.aot_config
fw_metadata = aot_state.fw_metadata
CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="autograd")
# Copied from aot_dispatch_autograd_graph.
disable_amp = torch._C._is_any_autocast_enabled()
joint_graph_str = None
if aot_config.enable_log:
aot_joint_log.info(
@@ -1376,13 +1423,120 @@ def aot_stage2_autograd(
"aot_joint_graph",
payload_fn=lambda: joint_graph_str,
)
return joint_graph_str
def _log_fw_bw_graphs(
fw_module: torch.fx.GraphModule,
bw_module: torch.fx.GraphModule,
maybe_subclass_meta: Optional[SubclassMeta],
fw_metadata: ViewAndMutationMeta,
aot_config: AOTConfig,
) -> tuple[Optional[str], Optional[str]]:
"""
Log the fw and bw graphs to the structured logger.
Return str representations of the graphs.
"""
fw_module_str = None
bw_module_str = None
if aot_config.enable_log:
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "torch._functorch.config",
"encoding": "string",
},
payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
)
aot_graphs_log.info(
"aot_config id: %s, fw_metadata=%s, inner_meta=%s",
str(aot_config.aot_id),
str(fw_metadata),
str(_get_inner_meta(maybe_subclass_meta, fw_metadata)),
)
aot_graphs_log.info(
"%s",
lazy_format_graph_code(
"Forward graph",
fw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
aot_graphs_log.info(
"%s",
lazy_format_graph_code(
"Backward graph",
bw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
fw_module_str = fw_module.print_readable(
print_output=False,
include_stride=True,
include_device=True,
expanded_def=True,
)
bw_module_str = bw_module.print_readable(
print_output=False,
include_stride=True,
include_device=True,
expanded_def=True,
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "aot_forward_graph_fw_metadata",
"encoding": "string",
},
payload_fn=lambda: dataclass_repr(fw_metadata),
)
if maybe_subclass_meta is not None:
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "aot_forward_graph_fw_subclass_metadata",
"encoding": "string",
},
payload_fn=lambda: dataclass_repr(maybe_subclass_meta),
)
trace_structured(
"aot_forward_graph",
payload_fn=lambda: fw_module_str,
)
trace_structured(
"aot_backward_graph",
payload_fn=lambda: bw_module_str,
)
return fw_module_str, bw_module_str
def _aot_stage2a_partition(
fx_g: torch.fx.GraphModule,
joint_inputs: Union[list[Any], tuple[list[Any], list[Any]]],
maybe_subclass_meta: Optional[SubclassMeta],
fw_metadata: ViewAndMutationMeta,
aot_config: AOTConfig,
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule, int, int, list[int], list[Any]]:
"""
Partition the joint graph into a forward graph and a backward graph. Returns:
- the forward and backward graphs
- the number of forward outputs and the number of symints saved for backward
- indices of inputs to detach
- adjusted inputs to forward
"""
disable_amp = torch._C._is_any_autocast_enabled()
inner_meta = _get_inner_meta(maybe_subclass_meta, fw_metadata)
with torch.no_grad():
inner_meta = (
fw_metadata
if maybe_subclass_meta is None
else maybe_subclass_meta.fw_metadata
)
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context(), track_graph_compiling(aot_config, "joint"):
# See Note: [Partitioner handling for Subclasses, Part 1]
@@ -1401,13 +1555,8 @@ def aot_stage2_autograd(
+ inner_meta.num_outputs_rng_offset
+ num_tokens # See Note [Side-Effectful Tokens in AOTAutograd]
)
fake_mode = detect_fake_mode()
fx_g = run_joint_graph_passes_on_hops(fx_g, joint_inputs, aot_config)
# TODO(anijain2305) - Add tensorify_python_scalars to the HOP graph passes.
if fake_mode is not None and fake_mode.shape_env is not None:
tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode)
# apply joint_gm callback here
if callable(torch._functorch.config.joint_custom_pass):
fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)
@@ -1475,9 +1624,9 @@ def aot_stage2_autograd(
if dynamic_dims:
fw_metadata.dynamic_saved_tensors_idxs[idx] = dynamic_dims
fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
fw_metadata.num_symints_saved_for_bw = num_symints_saved_for_bw
inner_meta.num_symints_saved_for_bw = num_symints_saved_for_bw
if torch._functorch.config.donated_buffer:
fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs(
fw_module,
@@ -1486,22 +1635,6 @@ def aot_stage2_autograd(
)
inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs
if aot_config.enable_log:
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "torch._functorch.config",
"encoding": "string",
},
payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
)
aot_graphs_log.info(
"aot_config id: %s, fw_metadata=%s, inner_meta=%s",
str(aot_config.aot_id),
str(fw_metadata),
str(inner_meta),
)
# Note [Detaching inputs that never need gradients]
# See https://github.com/pytorch/pytorch/issues/97745
# Suppose we have a function like this that we want to compile:
@@ -1598,78 +1731,44 @@ def aot_stage2_autograd(
if bw_out is None and not metadata_mutation_in_graph and is_non_leaf:
_indices_of_inps_to_detach.append(i)
fw_module_str = None
bw_module_str = None
if aot_config.enable_log:
aot_graphs_log.info(
"%s",
lazy_format_graph_code(
"Forward graph",
fw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
aot_graphs_log.info(
"%s",
lazy_format_graph_code(
"Backward graph",
bw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
fw_module_str = fw_module.print_readable(
print_output=False,
include_stride=True,
include_device=True,
expanded_def=True,
)
bw_module_str = bw_module.print_readable(
print_output=False,
include_stride=True,
include_device=True,
expanded_def=True,
)
return (
fw_module,
bw_module,
num_fw_outs_saved_for_bw,
num_symints_saved_for_bw,
_indices_of_inps_to_detach,
joint_inputs[0],
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "aot_forward_graph_fw_metadata",
"encoding": "string",
},
payload_fn=lambda: dataclass_repr(fw_metadata),
)
if maybe_subclass_meta is not None:
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "aot_forward_graph_fw_subclass_metadata",
"encoding": "string",
},
payload_fn=lambda: dataclass_repr(maybe_subclass_meta),
)
trace_structured(
"aot_forward_graph",
payload_fn=lambda: fw_module_str,
)
trace_structured(
"aot_backward_graph",
payload_fn=lambda: bw_module_str,
)
def _aot_stage2b_fw_compile(
fw_module: torch.fx.GraphModule,
adjusted_flat_args: list[Any],
maybe_subclass_meta: Optional[SubclassMeta],
fw_metadata: ViewAndMutationMeta,
num_fw_outs_saved_for_bw: int,
aot_config: AOTConfig,
) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]:
"""
Compile the forward graph. Returns:
- the output strides of the forward graph
- the compiled forward function
Before compiling, we run pre_compile for the following wrappers:
- FakifiedOutWrapper
- FunctionalizedRngRuntimeWrapper
After compiling, we run post_compile for the following wrappers:
- EffectTokensWrapper
- AOTDispatchSubclassWrapper
- FunctionalizedRngRuntimeWrapper
- FakifiedOutWrapper
"""
with torch.no_grad():
# AMP is already traced out in joint graph. we do not wish to reapply it accidentally
# in the compiler.
with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
# flat_args at this point might still be subclasses-
# make sure to pass the unwrapped fake tensors into the compiler!
adjusted_flat_args = joint_inputs[0]
fakified_out_wrapper = FakifiedOutWrapper()
fakified_out_wrapper.pre_compile(
fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
@@ -1679,7 +1778,7 @@ def aot_stage2_autograd(
return_new_outs=False
)
if rng_states:
if fw_metadata.num_graphsafe_rng_states > 0:
index = fw_metadata.graphsafe_rng_state_index
assert index is not None
rng_states = [
@@ -1692,7 +1791,9 @@ def aot_stage2_autograd(
fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
)
if tracing_context := torch._guards.TracingContext.try_get():
tracing_context.fw_metadata = inner_meta
tracing_context.fw_metadata = _get_inner_meta(
maybe_subclass_meta, fw_metadata
)
with TracingContext.report_output_strides() as fwd_output_strides:
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
@@ -1729,6 +1830,23 @@ def aot_stage2_autograd(
runtime_metadata=fw_metadata,
)
return fwd_output_strides, compiled_fw_func
def _aot_stage2b_bw_compile(
bw_module: torch.fx.GraphModule,
maybe_subclass_meta: Optional[SubclassMeta],
fw_metadata: ViewAndMutationMeta,
fwd_output_strides: Optional[list[Optional[tuple[int, ...]]]],
num_symints_saved_for_bw: int,
aot_config: AOTConfig,
) -> tuple[list[object], Optional[Callable]]:
"""
Compile the backward graph. Returns:
- the placeholder list for the backward graph
- the compiled backward function
"""
with torch.no_grad():
# NB: It's important to compile backwards ahead of time, as this may
# add extra guards which we need to apply to the Dynamo cache at
# forwards
@@ -1737,6 +1855,7 @@ def aot_stage2_autograd(
forward_saved_for_backwards_strides = None
if fwd_output_strides is not None:
inner_meta = _get_inner_meta(maybe_subclass_meta, fw_metadata)
forward_saved_for_backwards_strides = fwd_output_strides[
inner_meta.tensors_saved_for_backwards_slice
]
@@ -1838,9 +1957,69 @@ def aot_stage2_autograd(
_LazyGraphModule.force_recompile(bw_module)
return placeholder_list, compiled_bw_func
def aot_stage2_autograd(
aot_state: AOTState,
aot_graph_capture: AOTGraphCapture,
) -> DispatchReturn:
"""
Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers,
and returns a wrapped torch.autograd.Function with a forward and backward.
"""
fx_g = aot_graph_capture.graph_module
maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta
fw_metadata = aot_state.fw_metadata
aot_config = aot_state.aot_config
CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="autograd")
joint_graph_str = _log_joint_graph(fx_g, aot_config)
_apply_tensorify_python_scalars(fx_g)
(
fw_module,
bw_module,
num_fw_outs_saved_for_bw,
num_symints_saved_for_bw,
_indices_of_inps_to_detach,
adjusted_flat_args,
) = _aot_stage2a_partition(
fx_g,
aot_graph_capture.updated_flat_args,
maybe_subclass_meta,
fw_metadata,
aot_config,
)
fw_module_str, bw_module_str = _log_fw_bw_graphs(
fw_module, bw_module, maybe_subclass_meta, fw_metadata, aot_config
)
fwd_output_strides, compiled_fw_func = _aot_stage2b_fw_compile(
fw_module,
adjusted_flat_args,
maybe_subclass_meta,
fw_metadata,
num_fw_outs_saved_for_bw,
aot_config,
)
placeholder_list, compiled_bw_func = _aot_stage2b_bw_compile(
bw_module,
maybe_subclass_meta,
fw_metadata,
fwd_output_strides,
num_symints_saved_for_bw,
aot_config,
)
saved_context = TracingContext.try_get()
saved_compile_context = CompileContext.try_get()
flat_args = aot_state.flat_args
backward_state_indices = [
idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
]
@@ -1858,6 +2037,7 @@ def aot_stage2_autograd(
try_save_cache_entry: Optional[Callable] = None
entry: Optional[GenericAOTAutogradCacheEntry] = None
wrappers = aot_graph_capture.wrappers
if aot_config.cache_info is not None:
forward_time_taken_ns = time.time_ns() - aot_config.cache_info.start_time_ns
@@ -1926,6 +2106,7 @@ def aot_stage2_autograd(
)
try_save_cache_entry = None
disable_amp = torch._C._is_any_autocast_enabled()
compiled_fn = AOTDispatchAutograd.post_compile(
compiled_fw_func,
compiled_bw_func,