[stage 2c] make autograd and inference functions (#165668)
Add final stage of aot_stage2_compile for autograd and inference.

Differential Revision: D84844699
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165668
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
parent e20c9bf288
commit 259cb945f5
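For orientation, the sketch below is not the PyTorch source; it is a simplified, self-contained stand-in (hypothetical names and signatures) for the control flow this change introduces: stage 2 now ends with a "stage 2c" step that first records an optional cache entry and then builds the callable handed back to the caller.

# Illustrative sketch only: mirrors the shape of the new split
# (_cache_inference_info followed by _aot_stage2c_make_inference_function);
# all names and signatures here are simplified stand-ins.
from typing import Callable, Optional


def _cache_inference_info_sketch(cache_enabled: bool, compiled_fw: Callable) -> Optional[dict]:
    # Part 1 of stage 2c: build (and, in the real code, persist) a cache entry
    # when caching is enabled for this compile; otherwise return None.
    return {"compiled_fw": compiled_fw} if cache_enabled else None


def _make_inference_function_sketch(
    compiled_fw: Callable, wrappers: list, entry: Optional[dict]
) -> Callable:
    # Part 2 of stage 2c: compose the runtime wrappers around the compiled
    # forward; a non-None entry would additionally make the result serializable.
    fn = compiled_fw
    for wrap in wrappers:
        fn = wrap(fn)
    return fn


def aot_stage2_inference_sketch(
    compiled_fw: Callable, wrappers: list, cache_enabled: bool
) -> Callable:
    entry = _cache_inference_info_sketch(cache_enabled, compiled_fw)
    return _make_inference_function_sketch(compiled_fw, wrappers, entry)


if __name__ == "__main__":
    compiled = aot_stage2_inference_sketch(lambda x: x * 2, wrappers=[], cache_enabled=True)
    assert compiled(3) == 6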
@@ -361,6 +361,32 @@ def aot_stage2_inference(
         aot_config,
     )
 
+    entry = _cache_inference_info(
+        aot_config,
+        fw_metadata,
+        maybe_subclass_meta,
+        compiled_fw,
+        aot_forward_graph_str,
+        wrappers,
+    )
+
+    return _aot_stage2c_make_inference_function(
+        aot_config,
+        fw_metadata,
+        compiled_fw,
+        wrappers,
+        entry,
+    )
+
+
+def _cache_inference_info(
+    aot_config,
+    fw_metadata,
+    maybe_subclass_meta,
+    compiled_fw,
+    aot_forward_graph_str,
+    wrappers,
+):
     make_runtime_safe(fw_metadata, maybe_subclass_meta)
 
     cache_info = aot_config.cache_info
@@ -371,33 +397,47 @@ def aot_stage2_inference(
         else:
             return hasattr(compiled_fw, "_fx_graph_cache_key")
 
-    if cache_info is not None:
-        if should_save_cache():
-            time_taken_ns = time.time_ns() - cache_info.start_time_ns
-            guards_expr = AOTAutogradCache.generate_guards_expression(cache_info)
-            entry = AOTAutogradCache.make_entry(
-                compiled_fw_func=compiled_fw,  # type: ignore[arg-type]
-                compiled_bw_func=None,
-                aot_joint_graph_str=None,
-                aot_forward_graph_str=aot_forward_graph_str,
-                aot_backward_graph_str=None,
-                runtime_metadata=fw_metadata,
-                dispatch_wrappers=wrappers,
-                maybe_subclass_meta=maybe_subclass_meta,
-                num_fw_outs_saved_for_bw=None,
-                indices_of_inps_to_detach=[],
-                forward_time_taken_ns=time_taken_ns,
-                backward_time_taken_ns=0,
-                sanitized_aot_config=sanitize_aot_config(aot_config),
-                guards_expr=guards_expr,
-                backward_state_indices=None,
-                num_symints_saved_for_bw=None,
-                serialized_bw_module=None,
-            )
-            AOTAutogradCache.save(
-                cache_info.cache_key, entry, remote=should_use_remote_autograd_cache()
-            )
-            compiled_fw = SerializableCompiledFunction(compiled_fw, lambda: entry)
+    entry: Optional[GenericAOTAutogradCacheEntry] = None
+    if cache_info is not None and should_save_cache():
+        time_taken_ns = time.time_ns() - cache_info.start_time_ns
+        guards_expr = AOTAutogradCache.generate_guards_expression(cache_info)
+        entry = AOTAutogradCache.make_entry(
+            compiled_fw_func=compiled_fw,  # type: ignore[arg-type]
+            compiled_bw_func=None,
+            aot_joint_graph_str=None,
+            aot_forward_graph_str=aot_forward_graph_str,
+            aot_backward_graph_str=None,
+            runtime_metadata=fw_metadata,
+            dispatch_wrappers=wrappers,
+            maybe_subclass_meta=maybe_subclass_meta,
+            num_fw_outs_saved_for_bw=None,
+            indices_of_inps_to_detach=[],
+            forward_time_taken_ns=time_taken_ns,
+            backward_time_taken_ns=0,
+            sanitized_aot_config=sanitize_aot_config(aot_config),
+            guards_expr=guards_expr,
+            backward_state_indices=None,
+            num_symints_saved_for_bw=None,
+            serialized_bw_module=None,
+        )
+        AOTAutogradCache.save(
+            cache_info.cache_key,
+            entry,
+            remote=should_use_remote_autograd_cache(),
+        )
+
+    return entry
+
+
+def _aot_stage2c_make_inference_function(
+    aot_config,
+    fw_metadata,
+    compiled_fw,
+    wrappers,
+    entry,
+):
+    if entry is not None:
+        compiled_fw = SerializableCompiledFunction(compiled_fw, lambda: entry)
 
     disable_amp = torch._C._is_any_autocast_enabled()
     compiled_fn = RuntimeWrapper(
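In the hunk above, the compiled forward is wrapped for serialization only when caching actually produced an entry. The snippet below is a minimal mock of that pattern; SerializableCompiledFunctionMock and make_inference_function are illustrative stand-ins, not the classes used in the diff.

# Mock of the "wrap only when a cache entry exists" pattern; the class below is
# a stand-in, not torch's SerializableCompiledFunction.
from typing import Any, Callable, Optional


class SerializableCompiledFunctionMock:
    def __init__(self, fn: Callable, entry_fn: Callable[[], Any]):
        self.fn = fn
        self.entry_fn = entry_fn  # evaluated lazily when the entry is needed

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.fn(*args, **kwargs)


def make_inference_function(compiled_fw: Callable, entry: Optional[Any]) -> Callable:
    # Mirrors the shape of _aot_stage2c_make_inference_function: the entry is
    # captured in a lambda so it is only materialized if serialization happens.
    if entry is not None:
        return SerializableCompiledFunctionMock(compiled_fw, lambda: entry)
    return compiled_fw


if __name__ == "__main__":
    fn = make_inference_function(lambda x: x + 1, entry={"key": "abc"})
    assert fn(1) == 2 and fn.entry_fn()["key"] == "abc"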
@@ -1701,7 +1741,7 @@ def _aot_stage2b_bw_compile(
     fwd_output_strides: Optional[list[Optional[tuple[int, ...]]]],
     num_symints_saved_for_bw: int,
     aot_config: AOTConfig,
-) -> tuple[list[object], Optional[Callable]]:
+) -> tuple[AutogradLazyBackwardCompileInfo, Optional[Callable]]:
     """
     Compile the backward graph. Returns:
     - the placeholder list for the backward graph
@@ -1828,7 +1868,17 @@ def _aot_stage2b_bw_compile(
 
         _LazyGraphModule.force_recompile(bw_module)
 
-    return placeholder_list, compiled_bw_func
+    saved_context = TracingContext.try_get()
+    saved_compile_context = CompileContext.try_get()
+
+    lazy_backward_info = AutogradLazyBackwardCompileInfo(
+        bw_module,
+        placeholder_list,
+        saved_context,
+        saved_compile_context,
+    )
+
+    return lazy_backward_info, compiled_bw_func
 
 
 def aot_stage2_autograd(
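_aot_stage2b_bw_compile now hands back a lazy-backward bundle rather than a bare placeholder list. Below is a hypothetical, self-contained dataclass with matching field names to show what the caller receives; the types are simplified and this is not the real AutogradLazyBackwardCompileInfo.

from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class AutogradLazyBackwardCompileInfoSketch:
    bw_module: Any                  # traced backward graph, lowered on first backward
    placeholder_list: list          # example inputs for the deferred compile
    saved_context: Optional[Any]    # tracing context captured at compile time
    saved_compile_context: Optional[Any]


def bw_compile_sketch(
    bw_module: Any, placeholder_list: list
) -> tuple[AutogradLazyBackwardCompileInfoSketch, Optional[Callable]]:
    # Whether or not the backward was lowered eagerly, the caller always gets
    # the bundle needed to (re)compile it lazily at runtime.
    info = AutogradLazyBackwardCompileInfoSketch(bw_module, placeholder_list, None, None)
    compiled_bw_func = None  # stays None when backward lowering is deferred
    return info, compiled_bw_func


if __name__ == "__main__":
    info, bw = bw_compile_sketch(bw_module=object(), placeholder_list=[None, None])
    assert bw is None and len(info.placeholder_list) == 2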
@@ -1878,7 +1928,7 @@ def aot_stage2_autograd(
         aot_config,
     )
 
-    placeholder_list, compiled_bw_func = _aot_stage2b_bw_compile(
+    lazy_backward_info, compiled_bw_func = _aot_stage2b_bw_compile(
         bw_module,
         maybe_subclass_meta,
         fw_metadata,
@@ -1887,28 +1937,119 @@ def aot_stage2_autograd(
         aot_config,
     )
 
-    saved_context = TracingContext.try_get()
-    saved_compile_context = CompileContext.try_get()
+    try_save_cache_entry, entry = _cache_autograd_info(
+        aot_config,
+        aot_state.flat_args,
+        compiled_fw_func,
+        compiled_bw_func,
+        fw_module_str,
+        bw_module_str,
+        joint_graph_str,
+        aot_graph_capture.wrappers,
+        maybe_subclass_meta,
+        fw_metadata,
+        num_fw_outs_saved_for_bw,
+        _indices_of_inps_to_detach,
+        num_symints_saved_for_bw,
+        bw_module,
+    )
 
-    flat_args = aot_state.flat_args
+    return _aot_stage2c_make_autograd_function(
+        aot_config,
+        aot_state.flat_args,
+        fw_metadata,
+        maybe_subclass_meta,
+        aot_graph_capture.wrappers,
+        compiled_fw_func,
+        compiled_bw_func,
+        lazy_backward_info,
+        try_save_cache_entry,
+        entry,
+        _indices_of_inps_to_detach,
+        num_symints_saved_for_bw,
+    )
+
+
+def _aot_stage2c_make_autograd_function(
+    aot_config,
+    flat_args,
+    fw_metadata,
+    maybe_subclass_meta,
+    wrappers,
+    compiled_fw_func,
+    compiled_bw_func,
+    lazy_backward_info,
+    try_save_cache_entry,
+    entry,
+    _indices_of_inps_to_detach,
+    num_symints_saved_for_bw,
+):
     backward_state_indices = [
         idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
     ]
     assert len(backward_state_indices) <= 1
 
-    lazy_backward_info = AutogradLazyBackwardCompileInfo(
-        bw_module,
-        placeholder_list,
-        saved_context,
-        saved_compile_context,
+    disable_amp = torch._C._is_any_autocast_enabled()
+    compiled_fn = AOTDispatchAutograd.post_compile(
+        compiled_fw_func,
+        compiled_bw_func,
+        maybe_subclass_meta,
+        num_symints_saved_for_bw,
+        backward_state_indices,
+        disable_amp,
+        _indices_of_inps_to_detach,
+        lazy_backward_info,
+        aot_config,
+        fw_metadata=fw_metadata,
+        try_save_cache_entry=try_save_cache_entry,
     )
 
+    if entry is not None:
+        compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: entry)
+
+    if config.debug_assert:
+        flat_requires_grad: list[Optional[bool]] = [
+            a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
+        ]
+        compiled_fn = DebugAssertWrapper(
+            flat_requires_grad=flat_requires_grad
+        ).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata)
+
+    compiled_fn = post_compile(
+        wrappers,
+        compiled_fn,
+        aot_config,
+        runtime_metadata=fw_metadata,
+    )
+    return compiled_fn
+
+
+def _cache_autograd_info(
+    aot_config,
+    flat_args,
+    compiled_fw_func,
+    compiled_bw_func,
+    fw_module_str,
+    bw_module_str,
+    joint_graph_str,
+    wrappers,
+    maybe_subclass_meta,
+    fw_metadata,
+    num_fw_outs_saved_for_bw,
+    _indices_of_inps_to_detach,
+    num_symints_saved_for_bw,
+    bw_module,
+):
+    backward_state_indices = [
+        idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
+    ]
+    assert len(backward_state_indices) <= 1
+
     make_runtime_safe(fw_metadata, maybe_subclass_meta)
 
     try_save_cache_entry: Optional[Callable] = None
     entry: Optional[GenericAOTAutogradCacheEntry] = None
 
-    wrappers = aot_graph_capture.wrappers
     if aot_config.cache_info is not None:
         forward_time_taken_ns = time.time_ns() - aot_config.cache_info.start_time_ns
 
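The autograd path now mirrors the inference path: a caching helper returns an optional "save later" callback plus an optional eagerly built entry, and the function builder threads both into the runtime wrapper. The sketch below uses hypothetical names and plain dicts to show only that shape, not the real caching logic.

from typing import Callable, Optional


def _cache_autograd_info_sketch(
    cache_enabled: bool, bw_already_compiled: bool
) -> tuple[Optional[Callable], Optional[dict]]:
    if not cache_enabled:
        return None, None
    if bw_already_compiled:
        # Backward was lowered eagerly, so the entry can be built (and saved) now.
        return None, {"fw": "...", "bw": "..."}

    def try_save_cache_entry(compiled_bw_func: Callable) -> dict:
        # Deferred path: saving happens once the backward is actually compiled.
        return {"fw": "...", "bw": compiled_bw_func.__name__}

    return try_save_cache_entry, None


def _make_autograd_function_sketch(
    forward: Callable, try_save: Optional[Callable], entry: Optional[dict]
) -> Callable:
    # Stand-in for _aot_stage2c_make_autograd_function: the real code passes
    # try_save into the dispatcher and wraps the result when entry is not None.
    return forward


if __name__ == "__main__":
    try_save, entry = _cache_autograd_info_sketch(cache_enabled=True, bw_already_compiled=False)
    assert entry is None and callable(try_save)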
|
@ -1965,8 +2106,11 @@ def aot_stage2_autograd(
|
|||
num_symints_saved_for_bw=num_symints_saved_for_bw,
|
||||
serialized_bw_module=serialize_graph_module(bw_module),
|
||||
)
|
||||
remote = should_use_remote_autograd_cache()
|
||||
AOTAutogradCache.save(cache_info.cache_key, entry, remote)
|
||||
AOTAutogradCache.save(
|
||||
cache_info.cache_key,
|
||||
entry,
|
||||
remote=should_use_remote_autograd_cache(),
|
||||
)
|
||||
return entry
|
||||
return None
|
||||
|
||||
|
|
@@ -1977,39 +2121,7 @@ def aot_stage2_autograd(
             )
             try_save_cache_entry = None
 
-    disable_amp = torch._C._is_any_autocast_enabled()
-    compiled_fn = AOTDispatchAutograd.post_compile(
-        compiled_fw_func,
-        compiled_bw_func,
-        maybe_subclass_meta,
-        num_symints_saved_for_bw,
-        backward_state_indices,
-        disable_amp,
-        _indices_of_inps_to_detach,
-        lazy_backward_info,
-        aot_config,
-        fw_metadata=fw_metadata,
-        try_save_cache_entry=try_save_cache_entry,
-    )
-
-    if entry is not None:
-        compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: entry)
-
-    if config.debug_assert:
-        flat_requires_grad: list[Optional[bool]] = [
-            a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
-        ]
-        compiled_fn = DebugAssertWrapper(
-            flat_requires_grad=flat_requires_grad
-        ).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata)
-
-    compiled_fn = post_compile(
-        wrappers,
-        compiled_fn,
-        aot_config,
-        runtime_metadata=fw_metadata,
-    )
-    return compiled_fn
+    return try_save_cache_entry, entry
 
 
 def _aot_stage2b_compile_forward_or_inference(