[stage 2c] make autograd and inference functions (#165668)

Add the final stage (2c) of aot_stage2_compile for autograd and inference: cache-entry construction is factored into _cache_inference_info / _cache_autograd_info, and runtime-function assembly into _aot_stage2c_make_inference_function / _aot_stage2c_make_autograd_function.

Differential Revision: D84844699

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165668
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
Author: Avik Chaudhuri (2025-10-20 23:50:31 +00:00), committed by PyTorch MergeBot
parent e20c9bf288
commit 259cb945f5

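At a high level, stage 2c splits each path into two steps: build (and maybe save) a cache entry, then assemble the runtime function. A toy, self-contained model of the inference-side flow; every name suffixed _toy is an illustrative stand-in, not a PyTorch API:

    from typing import Any, Callable, Optional

    def _cache_inference_info_toy(cache_key: Optional[str], compiled_fw: Callable) -> Optional[dict]:
        # Build a serializable cache entry only when caching is enabled,
        # mirroring how the real helper returns None when there is no cache_info.
        if cache_key is None:
            return None
        return {"key": cache_key, "fw": compiled_fw}

    def _make_inference_function_toy(compiled_fw: Callable, entry: Optional[dict]) -> Callable:
        # Assemble the runtime function; a non-None entry marks it serializable.
        def runtime_fn(*args: Any) -> Any:
            return compiled_fw(*args)

        runtime_fn.cache_entry = entry  # type: ignore[attr-defined]
        return runtime_fn

    def aot_stage2_inference_toy(cache_key: Optional[str], compiled_fw: Callable) -> Callable:
        entry = _cache_inference_info_toy(cache_key, compiled_fw)
        return _make_inference_function_toy(compiled_fw, entry)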

@@ -361,6 +361,32 @@ def aot_stage2_inference(
         aot_config,
     )
 
+    entry = _cache_inference_info(
+        aot_config,
+        fw_metadata,
+        maybe_subclass_meta,
+        compiled_fw,
+        aot_forward_graph_str,
+        wrappers,
+    )
+
+    return _aot_stage2c_make_inference_function(
+        aot_config,
+        fw_metadata,
+        compiled_fw,
+        wrappers,
+        entry,
+    )
+
+
+def _cache_inference_info(
+    aot_config,
+    fw_metadata,
+    maybe_subclass_meta,
+    compiled_fw,
+    aot_forward_graph_str,
+    wrappers,
+):
     make_runtime_safe(fw_metadata, maybe_subclass_meta)
 
     cache_info = aot_config.cache_info
@@ -371,33 +397,47 @@ def aot_stage2_inference(
         else:
             return hasattr(compiled_fw, "_fx_graph_cache_key")
 
-    if cache_info is not None:
-        if should_save_cache():
-            time_taken_ns = time.time_ns() - cache_info.start_time_ns
-            guards_expr = AOTAutogradCache.generate_guards_expression(cache_info)
-            entry = AOTAutogradCache.make_entry(
-                compiled_fw_func=compiled_fw,  # type: ignore[arg-type]
-                compiled_bw_func=None,
-                aot_joint_graph_str=None,
-                aot_forward_graph_str=aot_forward_graph_str,
-                aot_backward_graph_str=None,
-                runtime_metadata=fw_metadata,
-                dispatch_wrappers=wrappers,
-                maybe_subclass_meta=maybe_subclass_meta,
-                num_fw_outs_saved_for_bw=None,
-                indices_of_inps_to_detach=[],
-                forward_time_taken_ns=time_taken_ns,
-                backward_time_taken_ns=0,
-                sanitized_aot_config=sanitize_aot_config(aot_config),
-                guards_expr=guards_expr,
-                backward_state_indices=None,
-                num_symints_saved_for_bw=None,
-                serialized_bw_module=None,
-            )
-            AOTAutogradCache.save(
-                cache_info.cache_key, entry, remote=should_use_remote_autograd_cache()
-            )
-            compiled_fw = SerializableCompiledFunction(compiled_fw, lambda: entry)
+    entry: Optional[GenericAOTAutogradCacheEntry] = None
+    if cache_info is not None and should_save_cache():
+        time_taken_ns = time.time_ns() - cache_info.start_time_ns
+        guards_expr = AOTAutogradCache.generate_guards_expression(cache_info)
+        entry = AOTAutogradCache.make_entry(
+            compiled_fw_func=compiled_fw,  # type: ignore[arg-type]
+            compiled_bw_func=None,
+            aot_joint_graph_str=None,
+            aot_forward_graph_str=aot_forward_graph_str,
+            aot_backward_graph_str=None,
+            runtime_metadata=fw_metadata,
+            dispatch_wrappers=wrappers,
+            maybe_subclass_meta=maybe_subclass_meta,
+            num_fw_outs_saved_for_bw=None,
+            indices_of_inps_to_detach=[],
+            forward_time_taken_ns=time_taken_ns,
+            backward_time_taken_ns=0,
+            sanitized_aot_config=sanitize_aot_config(aot_config),
+            guards_expr=guards_expr,
+            backward_state_indices=None,
+            num_symints_saved_for_bw=None,
+            serialized_bw_module=None,
+        )
+        AOTAutogradCache.save(
+            cache_info.cache_key,
+            entry,
+            remote=should_use_remote_autograd_cache(),
+        )
+    return entry
+
+
+def _aot_stage2c_make_inference_function(
+    aot_config,
+    fw_metadata,
+    compiled_fw,
+    wrappers,
+    entry,
+):
+    if entry is not None:
+        compiled_fw = SerializableCompiledFunction(compiled_fw, lambda: entry)
 
     disable_amp = torch._C._is_any_autocast_enabled()
     compiled_fn = RuntimeWrapper(
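The SerializableCompiledFunction(compiled_fw, lambda: entry) pattern above pairs a compiled callable with a thunk that recovers its cache entry on demand. A minimal sketch of that shape, assuming only what the call sites in this diff show (the real class lives in the AOTAutograd cache machinery):

    from typing import Any, Callable

    class SerializableCompiledFunctionSketch:
        # Illustrative stand-in: delegate calls, and keep a zero-arg thunk
        # that yields the cache entry when serialization needs it.
        def __init__(self, fn: Callable, entry_fn: Callable[[], Any]) -> None:
            self.fn = fn
            self.entry_fn = entry_fn

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            return self.fn(*args, **kwargs)

        def cache_entry(self) -> Any:
            return self.entry_fn()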
@@ -1701,7 +1741,7 @@ def _aot_stage2b_bw_compile(
     fwd_output_strides: Optional[list[Optional[tuple[int, ...]]]],
     num_symints_saved_for_bw: int,
     aot_config: AOTConfig,
-) -> tuple[list[object], Optional[Callable]]:
+) -> tuple[AutogradLazyBackwardCompileInfo, Optional[Callable]]:
     """
     Compile the backward graph. Returns:
     - the placeholder list for the backward graph
@@ -1828,7 +1868,17 @@ def _aot_stage2b_bw_compile(
             _LazyGraphModule.force_recompile(bw_module)
 
-    return placeholder_list, compiled_bw_func
+    saved_context = TracingContext.try_get()
+    saved_compile_context = CompileContext.try_get()
+
+    lazy_backward_info = AutogradLazyBackwardCompileInfo(
+        bw_module,
+        placeholder_list,
+        saved_context,
+        saved_compile_context,
+    )
+
+    return lazy_backward_info, compiled_bw_func
 
 
 def aot_stage2_autograd(
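With this change _aot_stage2b_bw_compile returns everything needed to finish compiling the backward lazily, instead of a bare placeholder list. A hedged sketch of what such a container carries; the field names follow the constructor call above, while the types are loose stand-ins for the real class in the runtime wrappers:

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class LazyBackwardCompileInfoSketch:
        bw_module: Any                        # traced backward graph module
        placeholder_list: list                # example inputs for the backward compile
        saved_context: Optional[Any]          # TracingContext captured at compile time
        saved_compile_context: Optional[Any]  # CompileContext captured at compile time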
@@ -1878,7 +1928,7 @@ def aot_stage2_autograd(
         aot_config,
     )
 
-    placeholder_list, compiled_bw_func = _aot_stage2b_bw_compile(
+    lazy_backward_info, compiled_bw_func = _aot_stage2b_bw_compile(
         bw_module,
         maybe_subclass_meta,
         fw_metadata,
@@ -1887,28 +1937,119 @@ def aot_stage2_autograd(
         aot_config,
     )
 
-    saved_context = TracingContext.try_get()
-    saved_compile_context = CompileContext.try_get()
+    try_save_cache_entry, entry = _cache_autograd_info(
+        aot_config,
+        aot_state.flat_args,
+        compiled_fw_func,
+        compiled_bw_func,
+        fw_module_str,
+        bw_module_str,
+        joint_graph_str,
+        aot_graph_capture.wrappers,
+        maybe_subclass_meta,
+        fw_metadata,
+        num_fw_outs_saved_for_bw,
+        _indices_of_inps_to_detach,
+        num_symints_saved_for_bw,
+        bw_module,
+    )
 
-    flat_args = aot_state.flat_args
+    return _aot_stage2c_make_autograd_function(
+        aot_config,
+        aot_state.flat_args,
+        fw_metadata,
+        maybe_subclass_meta,
+        aot_graph_capture.wrappers,
+        compiled_fw_func,
+        compiled_bw_func,
+        lazy_backward_info,
+        try_save_cache_entry,
+        entry,
+        _indices_of_inps_to_detach,
+        num_symints_saved_for_bw,
+    )
+
+
+def _aot_stage2c_make_autograd_function(
+    aot_config,
+    flat_args,
+    fw_metadata,
+    maybe_subclass_meta,
+    wrappers,
+    compiled_fw_func,
+    compiled_bw_func,
+    lazy_backward_info,
+    try_save_cache_entry,
+    entry,
+    _indices_of_inps_to_detach,
+    num_symints_saved_for_bw,
+):
     backward_state_indices = [
         idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
     ]
     assert len(backward_state_indices) <= 1
 
-    lazy_backward_info = AutogradLazyBackwardCompileInfo(
-        bw_module,
-        placeholder_list,
-        saved_context,
-        saved_compile_context,
-    )
+    disable_amp = torch._C._is_any_autocast_enabled()
 
+    compiled_fn = AOTDispatchAutograd.post_compile(
+        compiled_fw_func,
+        compiled_bw_func,
+        maybe_subclass_meta,
+        num_symints_saved_for_bw,
+        backward_state_indices,
+        disable_amp,
+        _indices_of_inps_to_detach,
+        lazy_backward_info,
+        aot_config,
+        fw_metadata=fw_metadata,
+        try_save_cache_entry=try_save_cache_entry,
+    )
+
+    if entry is not None:
+        compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: entry)
+
+    if config.debug_assert:
+        flat_requires_grad: list[Optional[bool]] = [
+            a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
+        ]
+        compiled_fn = DebugAssertWrapper(
+            flat_requires_grad=flat_requires_grad
+        ).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata)
+
+    compiled_fn = post_compile(
+        wrappers,
+        compiled_fn,
+        aot_config,
+        runtime_metadata=fw_metadata,
+    )
+    return compiled_fn
+
+
+def _cache_autograd_info(
+    aot_config,
+    flat_args,
+    compiled_fw_func,
+    compiled_bw_func,
+    fw_module_str,
+    bw_module_str,
+    joint_graph_str,
+    wrappers,
+    maybe_subclass_meta,
+    fw_metadata,
+    num_fw_outs_saved_for_bw,
+    _indices_of_inps_to_detach,
+    num_symints_saved_for_bw,
+    bw_module,
+):
+    backward_state_indices = [
+        idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
+    ]
+    assert len(backward_state_indices) <= 1
+
     make_runtime_safe(fw_metadata, maybe_subclass_meta)
 
     try_save_cache_entry: Optional[Callable] = None
-    wrappers = aot_graph_capture.wrappers
+    entry: Optional[GenericAOTAutogradCacheEntry] = None
 
     if aot_config.cache_info is not None:
         forward_time_taken_ns = time.time_ns() - aot_config.cache_info.start_time_ns
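On the autograd path caching is two-mode: the entry may be built eagerly, or a try_save_cache_entry callback is kept around to run once the backward is actually compiled. A toy model of that contract, assuming (as the surrounding code suggests) that at most one of the returned pair is populated:

    from typing import Any, Callable, Optional

    def cache_autograd_info_sketch(
        cache_key: Optional[str], save_eagerly: bool
    ) -> tuple[Optional[Callable[[Any], Optional[dict]]], Optional[dict]]:
        # Returns (deferred_save_callback, eager_entry):
        #   no caching     -> (None, None)
        #   eager/bundled  -> (None, entry)
        #   deferred       -> (callback, None)
        if cache_key is None:
            return None, None

        def try_save(compiled_bw: Any) -> Optional[dict]:
            # Runs later, once the compiled backward exists.
            return {"key": cache_key, "bw": compiled_bw}

        if save_eagerly:
            return None, try_save(None)
        return try_save, None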
@@ -1965,8 +2106,11 @@ def aot_stage2_autograd(
                     num_symints_saved_for_bw=num_symints_saved_for_bw,
                     serialized_bw_module=serialize_graph_module(bw_module),
                 )
-                remote = should_use_remote_autograd_cache()
-                AOTAutogradCache.save(cache_info.cache_key, entry, remote)
+                AOTAutogradCache.save(
+                    cache_info.cache_key,
+                    entry,
+                    remote=should_use_remote_autograd_cache(),
+                )
                 return entry
             return None
@@ -1977,39 +2121,7 @@ def aot_stage2_autograd(
             )
 
             try_save_cache_entry = None
 
-    disable_amp = torch._C._is_any_autocast_enabled()
-
-    compiled_fn = AOTDispatchAutograd.post_compile(
-        compiled_fw_func,
-        compiled_bw_func,
-        maybe_subclass_meta,
-        num_symints_saved_for_bw,
-        backward_state_indices,
-        disable_amp,
-        _indices_of_inps_to_detach,
-        lazy_backward_info,
-        aot_config,
-        fw_metadata=fw_metadata,
-        try_save_cache_entry=try_save_cache_entry,
-    )
-
-    if entry is not None:
-        compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: entry)
-
-    if config.debug_assert:
-        flat_requires_grad: list[Optional[bool]] = [
-            a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
-        ]
-        compiled_fn = DebugAssertWrapper(
-            flat_requires_grad=flat_requires_grad
-        ).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata)
-
-    compiled_fn = post_compile(
-        wrappers,
-        compiled_fn,
-        aot_config,
-        runtime_metadata=fw_metadata,
-    )
-    return compiled_fn
+    return try_save_cache_entry, entry
 
 
 def _aot_stage2b_compile_forward_or_inference(
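Taken together, the autograd endgame is now a pure assembly step. A rough, self-contained model of what _aot_stage2c_make_autograd_function does with its inputs; the wrapper behavior and ordering here are simplified assumptions, not the exact post_compile semantics:

    from typing import Any, Callable, Optional

    def make_autograd_function_sketch(
        compiled_fw: Callable,
        entry: Optional[dict],
        wrappers: list,
    ) -> Callable:
        # 1. Start from the fused runtime callable (stands in for the result
        #    of AOTDispatchAutograd.post_compile).
        fn: Callable = compiled_fw

        # 2. A cache entry, if any, makes the function serializable.
        if entry is not None:
            inner = fn

            def fn_with_entry(*args: Any) -> Any:
                return inner(*args)

            fn_with_entry.cache_entry = entry  # type: ignore[attr-defined]
            fn = fn_with_entry

        # 3. Re-apply the dispatch wrappers recorded during tracing.
        for wrapper in wrappers:
            fn = wrapper(fn)
        return fn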