outline various stages from aot stage2 compile (#164808)
Splits the training and inference paths for aot stage2 compile.

1. Split `aot_stage2_autograd` into `_aot_stage2a_partition`, `_aot_stage2b_fw_compile`, `_aot_stage2b_bw_compile`, and the rest.
2. Split `aot_stage2_inference` into `_aot_stage2b_inference_compile` and the rest.

I'm leaving these as functions with underscore names since the I/O interfaces and the exact boundaries of these splits are still somewhat up in the air.

Differential Revision: D84028203
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164808
Approved by: https://github.com/SherlockNoMad
This commit is contained in: parent d41aa187ec, commit 3ed90f5a09
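For orientation before reading the diff, here is a condensed sketch (not the verbatim upstream code; argument lists are elided with `...` and only the call structure is shown) of how the split stages compose after this change, using the helper names introduced below:

def aot_stage2_autograd(aot_state, aot_graph_capture):
    joint_graph_str = _log_joint_graph(...)                              # logging split out
    (fw_module, bw_module, num_fw_outs_saved_for_bw, num_symints_saved_for_bw,
     _indices_of_inps_to_detach, adjusted_flat_args) = _aot_stage2a_partition(...)   # stage 2a
    fw_module_str, bw_module_str = _log_fw_bw_graphs(...)                # logging split out
    fwd_output_strides, compiled_fw_func = _aot_stage2b_fw_compile(...)  # stage 2b (forward)
    placeholder_list, compiled_bw_func = _aot_stage2b_bw_compile(...)    # stage 2b (backward)
    ...  # caching, RuntimeWrapper / autograd.Function setup stays in this function

def aot_stage2_inference(aot_state, aot_graph_capture):
    aot_forward_graph_str = _log_inference_graph(...)                    # logging split out
    compiled_fw = _aot_stage2b_inference_compile(...)                    # stage 2b (inference)
    ...  # runtime wrappers and caching stay in this function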
@@ -84,6 +84,7 @@ from .schemas import (
    FlatFn,
    FxValue,
    MutationType,
    SubclassMeta,
    ViewAndMutationMeta,
)
from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta
@@ -235,6 +236,28 @@ def sanitize_aot_config(input: AOTConfig) -> AOTConfig:
    )


def _get_inner_meta(
    maybe_subclass_meta: Optional[SubclassMeta],
    fw_metadata: ViewAndMutationMeta,
) -> ViewAndMutationMeta:
    """
    Util to get view and mutation metadata.
    """
    return (
        fw_metadata if maybe_subclass_meta is None else maybe_subclass_meta.fw_metadata
    )


def _apply_tensorify_python_scalars(module: torch.fx.GraphModule) -> None:
    """
    Util to apply tensorify_python_scalars.
    """
    # TODO(anijain2305) - Add tensorify_python_scalars to the HOP graph passes.
    fake_mode = detect_fake_mode()
    if fake_mode is not None and fake_mode.shape_env is not None:
        tensorify_python_scalars(module, fake_mode.shape_env, fake_mode)


def aot_stage2_compile(
    aot_state: AOTState,
    aot_graph_capture: AOTGraphCapture,
@@ -259,6 +282,125 @@ def aot_stage2_compile(
        return aot_stage2_inference(aot_state, aot_graph_capture)


def _log_inference_graph(
    fw_module: torch.fx.GraphModule,
    aot_config: AOTConfig,
) -> Optional[str]:
    """
    Log the inference graph to the structured logger.
    Return a str representation of the graph.
    """
    if aot_config.enable_log:
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "torch._functorch.config",
                "encoding": "string",
            },
            payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
        )

    # Save the forward_graph_str right after aot_dispatch_base_graph,
    # to save in the cache
    aot_forward_graph_str = None
    if aot_config.cache_info is not None:
        aot_forward_graph_str = fw_module.print_readable(
            print_output=False,
            include_stride=True,
            include_device=True,
            fast_sympy_print=True,
            expanded_def=True,
        )

    return aot_forward_graph_str


def _aot_stage2b_inference_compile(
    fw_module: torch.fx.GraphModule,
    updated_flat_args: list[Any],
    maybe_subclass_meta: Optional[SubclassMeta],
    fw_metadata: ViewAndMutationMeta,
    aot_config,
) -> Callable:
    """
    Compile the inference graph. Returns the compiled inference function.

    Mostly this is very similar to _aot_stage2b_fw_compile.

    Before compiling, we run pre_compile for the following wrappers:
    - FakifiedOutWrapper
    - FunctionalizedRngRuntimeWrapper
    After compiling, we run post_compile for the following wrappers:
    - EffectTokensWrapper
    - AOTDispatchSubclassWrapper
    - FunctionalizedRngRuntimeWrapper
    - FakifiedOutWrapper
    """
    disable_amp = torch._C._is_any_autocast_enabled()
    context = torch._C._DisableAutocast if disable_amp else nullcontext

    with context(), track_graph_compiling(aot_config, "inference"):
        fakified_out_wrapper = FakifiedOutWrapper()
        fakified_out_wrapper.pre_compile(
            fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
        )
        functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper()
        functionalized_rng_wrapper.pre_compile(
            fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
        )

        if tracing_context := torch._guards.TracingContext.try_get():
            tracing_context.fw_metadata = _get_inner_meta(
                maybe_subclass_meta,
                fw_metadata,
            )

        with TracingContext.report_output_strides() as fwd_output_strides:
            compiled_fw = aot_config.inference_compiler(fw_module, updated_flat_args)

        # However, RuntimeWrapper does not expect the rng offsets in the
        # output. So, we have to create another wrapper and take out the offset. As
        # a result, we have to account for not boxed_call compilers as well.
        if not getattr(compiled_fw, "_boxed_call", False):
            compiled_fw = make_boxed_func(compiled_fw)

        if fakified_out_wrapper.needs_post_compile:
            fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)

        compiled_fw = EffectTokensWrapper().post_compile(
            compiled_fw,
            aot_config,
            runtime_metadata=fw_metadata,
        )

        # Why do we need to pass in num_fw_outs_saved_for_bw?
        # See Note: [Partitioner handling for Subclasses, Part 2]
        compiled_fw = AOTDispatchSubclassWrapper(
            trace_joint=False,
            # TODO: once we use pre_compile this will be flat_fn at the top of this function
            fw_only=None,
            maybe_subclass_meta=maybe_subclass_meta,
            num_fw_outs_saved_for_bw=None,
        ).post_compile(
            compiled_fw,
            aot_config,  # not used
            runtime_metadata=fw_metadata,
        )

        # Create a wrapper to set up the rng functionalize and fakified out bits
        compiled_fw = functionalized_rng_wrapper.post_compile(
            compiled_fw, aot_config, runtime_metadata=fw_metadata
        )

        compiled_fw = fakified_out_wrapper.post_compile(
            compiled_fw,
            aot_config,
            runtime_metadata=fw_metadata,
        )

    return compiled_fw


def aot_stage2_inference(
    aot_state: AOTState,
    aot_graph_capture: AOTGraphCapture,
@@ -275,100 +417,17 @@ def aot_stage2_inference(
    maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta

    CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="inference")
    aot_forward_graph_str = _log_inference_graph(fw_module, aot_config)

    # Save the forward_graph_str right after aot_dispatch_base_graph,
    # to save in the cache
    aot_forward_graph_str = None
    if aot_config.cache_info is not None:
        aot_forward_graph_str = fw_module.print_readable(
            print_output=False,
            include_stride=True,
            include_device=True,
            fast_sympy_print=True,
            expanded_def=True,
        )

    fakified_out_wrapper = FakifiedOutWrapper()
    fakified_out_wrapper.pre_compile(
        fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
    )
    functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper()
    functionalized_rng_wrapper.pre_compile(
        fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
    )
    assert isinstance(fw_module, GraphModule)
    _apply_tensorify_python_scalars(fw_module)

    if aot_config.enable_log:
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "torch._functorch.config",
                "encoding": "string",
            },
            payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
        )

    disable_amp = torch._C._is_any_autocast_enabled()
    context = torch._C._DisableAutocast if disable_amp else nullcontext

    with context(), track_graph_compiling(aot_config, "inference"):
        compiler = (
            aot_config.inference_compiler
            if aot_config.inference_compiler is not None
            else aot_config.fw_compiler
        )

        if tracing_context := torch._guards.TracingContext.try_get():
            tracing_context.fw_metadata = (
                fw_metadata
                if maybe_subclass_meta is None
                else maybe_subclass_meta.fw_metadata
            )

        with TracingContext.report_output_strides() as fwd_output_strides:
            fake_mode = detect_fake_mode()
            if fake_mode is not None and fake_mode.shape_env is not None:
                tensorify_python_scalars(fw_module, fake_mode.shape_env, fake_mode)
            compiled_fw = compiler(fw_module, updated_flat_args)

    # However, RuntimeWrapper does not expect the rng offsets in the
    # output. So, we have to create another wrapper and take out the offset. As
    # a result, we have to account for not boxed_call compilers as well.
    if not getattr(compiled_fw, "_boxed_call", False):
        compiled_fw = make_boxed_func(compiled_fw)

    if fakified_out_wrapper.needs_post_compile:
        fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)

    compiled_fw = EffectTokensWrapper().post_compile(
        compiled_fw,
    compiled_fw = _aot_stage2b_inference_compile(
        fw_module,
        updated_flat_args,  # type: ignore[arg-type]
        maybe_subclass_meta,
        fw_metadata,
        aot_config,
        runtime_metadata=fw_metadata,
    )

    # Why do we need to pass in num_fw_outs_saved_for_bw?
    # See Note: [Partitioner handling for Subclasses, Part 2]
    compiled_fw = AOTDispatchSubclassWrapper(
        trace_joint=False,
        # TODO: once we use pre_compile this will be flat_fn at the top of this function
        fw_only=None,
        maybe_subclass_meta=maybe_subclass_meta,
        num_fw_outs_saved_for_bw=None,
    ).post_compile(
        compiled_fw,
        aot_config,  # not used
        runtime_metadata=fw_metadata,
    )

    # Create a wrapper to set up the rng functionalize and fakified out bits
    compiled_fw = functionalized_rng_wrapper.post_compile(
        compiled_fw, aot_config, runtime_metadata=fw_metadata
    )

    compiled_fw = fakified_out_wrapper.post_compile(
        compiled_fw,
        aot_config,
        runtime_metadata=fw_metadata,
    )

    make_runtime_safe(fw_metadata, maybe_subclass_meta)
@@ -409,6 +468,7 @@ def aot_stage2_inference(
        )
        compiled_fw = SerializableCompiledFunction(compiled_fw, lambda: entry)

    disable_amp = torch._C._is_any_autocast_enabled()
    compiled_fn = RuntimeWrapper(
        indices_of_inps_to_detach=[],
        trace_joint=False,
@@ -1332,27 +1392,14 @@ def maybe_inline_graph_saved_tensors_hooks(
    bw_module.recompile()


def aot_stage2_autograd(
    aot_state: AOTState,
    aot_graph_capture: AOTGraphCapture,
) -> DispatchReturn:
def _log_joint_graph(
    fx_g: torch.fx.GraphModule,
    aot_config: AOTConfig,
) -> Optional[str]:
    """
    Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers,
    and returns a wrapped torch.autograd.Function with a forward and backward.
    Log the joint graph to the structured logger.
    Return a str representation of the graph.
    """

    wrappers = aot_graph_capture.wrappers
    fx_g = aot_graph_capture.graph_module
    flat_args = aot_state.flat_args
    joint_inputs = aot_graph_capture.updated_flat_args
    maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta
    aot_config = aot_state.aot_config
    fw_metadata = aot_state.fw_metadata

    CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="autograd")

    # Copied from aot_dispatch_autograd_graph.
    disable_amp = torch._C._is_any_autocast_enabled()
    joint_graph_str = None
    if aot_config.enable_log:
        aot_joint_log.info(
@@ -1376,13 +1423,120 @@ def aot_stage2_autograd(
            "aot_joint_graph",
            payload_fn=lambda: joint_graph_str,
        )
    return joint_graph_str


def _log_fw_bw_graphs(
    fw_module: torch.fx.GraphModule,
    bw_module: torch.fx.GraphModule,
    maybe_subclass_meta: Optional[SubclassMeta],
    fw_metadata: ViewAndMutationMeta,
    aot_config: AOTConfig,
) -> tuple[Optional[str], Optional[str]]:
    """
    Log the fw and bw graphs to the structured logger.
    Return str representations of the graphs.
    """
    fw_module_str = None
    bw_module_str = None
    if aot_config.enable_log:
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "torch._functorch.config",
                "encoding": "string",
            },
            payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
        )
        aot_graphs_log.info(
            "aot_config id: %s, fw_metadata=%s, inner_meta=%s",
            str(aot_config.aot_id),
            str(fw_metadata),
            str(_get_inner_meta(maybe_subclass_meta, fw_metadata)),
        )

        aot_graphs_log.info(
            "%s",
            lazy_format_graph_code(
                "Forward graph",
                fw_module,
                aot_config.aot_id,
                include_stride=True,
                include_device=True,
                colored=True,
            ),
        )
        aot_graphs_log.info(
            "%s",
            lazy_format_graph_code(
                "Backward graph",
                bw_module,
                aot_config.aot_id,
                include_stride=True,
                include_device=True,
                colored=True,
            ),
        )
        fw_module_str = fw_module.print_readable(
            print_output=False,
            include_stride=True,
            include_device=True,
            expanded_def=True,
        )
        bw_module_str = bw_module.print_readable(
            print_output=False,
            include_stride=True,
            include_device=True,
            expanded_def=True,
        )

        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "aot_forward_graph_fw_metadata",
                "encoding": "string",
            },
            payload_fn=lambda: dataclass_repr(fw_metadata),
        )
        if maybe_subclass_meta is not None:
            trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "aot_forward_graph_fw_subclass_metadata",
                    "encoding": "string",
                },
                payload_fn=lambda: dataclass_repr(maybe_subclass_meta),
            )

        trace_structured(
            "aot_forward_graph",
            payload_fn=lambda: fw_module_str,
        )
        trace_structured(
            "aot_backward_graph",
            payload_fn=lambda: bw_module_str,
        )
    return fw_module_str, bw_module_str


def _aot_stage2a_partition(
    fx_g: torch.fx.GraphModule,
    joint_inputs: Union[list[Any], tuple[list[Any], list[Any]]],
    maybe_subclass_meta: Optional[SubclassMeta],
    fw_metadata: ViewAndMutationMeta,
    aot_config: AOTConfig,
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule, int, int, list[int], list[Any]]:
    """
    Partition the joint graph into a forward graph and a backward graph. Returns:
    - the forward and backward graphs
    - the number of forward outputs and the number of symints saved for backward
    - indices of inputs to detach
    - adjusted inputs to forward
    """
    disable_amp = torch._C._is_any_autocast_enabled()
    inner_meta = _get_inner_meta(maybe_subclass_meta, fw_metadata)

    with torch.no_grad():
        inner_meta = (
            fw_metadata
            if maybe_subclass_meta is None
            else maybe_subclass_meta.fw_metadata
        )
        context = torch._C._DisableAutocast if disable_amp else nullcontext
        with context(), track_graph_compiling(aot_config, "joint"):
            # See Note: [Partitioner handling for Subclasses, Part 1]
@@ -1401,13 +1555,8 @@ def aot_stage2_autograd(
                + inner_meta.num_outputs_rng_offset
                + num_tokens  # See Note [Side-Effectful Tokens in AOTAutograd]
            )
            fake_mode = detect_fake_mode()
            fx_g = run_joint_graph_passes_on_hops(fx_g, joint_inputs, aot_config)

            # TODO(anijain2305) - Add tensorify_python_scalars to the HOP graph passes.
            if fake_mode is not None and fake_mode.shape_env is not None:
                tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode)

            # apply joint_gm callback here
            if callable(torch._functorch.config.joint_custom_pass):
                fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)
@@ -1475,9 +1624,9 @@ def aot_stage2_autograd(
                if dynamic_dims:
                    fw_metadata.dynamic_saved_tensors_idxs[idx] = dynamic_dims

            fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
            inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
            num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
            fw_metadata.num_symints_saved_for_bw = num_symints_saved_for_bw
            inner_meta.num_symints_saved_for_bw = num_symints_saved_for_bw
            if torch._functorch.config.donated_buffer:
                fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs(
                    fw_module,
@@ -1486,22 +1635,6 @@ def aot_stage2_autograd(
                )
                inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs

        if aot_config.enable_log:
            trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "torch._functorch.config",
                    "encoding": "string",
                },
                payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
            )
            aot_graphs_log.info(
                "aot_config id: %s, fw_metadata=%s, inner_meta=%s",
                str(aot_config.aot_id),
                str(fw_metadata),
                str(inner_meta),
            )

        # Note [Detaching inputs that never need gradients]
        # See https://github.com/pytorch/pytorch/issues/97745
        # Suppose we have a function like this that we want to compile:
@@ -1598,78 +1731,44 @@ def aot_stage2_autograd(
            if bw_out is None and not metadata_mutation_in_graph and is_non_leaf:
                _indices_of_inps_to_detach.append(i)

        fw_module_str = None
        bw_module_str = None
        if aot_config.enable_log:
            aot_graphs_log.info(
                "%s",
                lazy_format_graph_code(
                    "Forward graph",
                    fw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
            aot_graphs_log.info(
                "%s",
                lazy_format_graph_code(
                    "Backward graph",
                    bw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
            fw_module_str = fw_module.print_readable(
                print_output=False,
                include_stride=True,
                include_device=True,
                expanded_def=True,
            )
            bw_module_str = bw_module.print_readable(
                print_output=False,
                include_stride=True,
                include_device=True,
                expanded_def=True,
            )
    return (
        fw_module,
        bw_module,
        num_fw_outs_saved_for_bw,
        num_symints_saved_for_bw,
        _indices_of_inps_to_detach,
        joint_inputs[0],
    )

            trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "aot_forward_graph_fw_metadata",
                    "encoding": "string",
                },
                payload_fn=lambda: dataclass_repr(fw_metadata),
            )
            if maybe_subclass_meta is not None:
                trace_structured(
                    "artifact",
                    metadata_fn=lambda: {
                        "name": "aot_forward_graph_fw_subclass_metadata",
                        "encoding": "string",
                    },
                    payload_fn=lambda: dataclass_repr(maybe_subclass_meta),
                )

            trace_structured(
                "aot_forward_graph",
                payload_fn=lambda: fw_module_str,
            )
            trace_structured(
                "aot_backward_graph",
                payload_fn=lambda: bw_module_str,
            )
def _aot_stage2b_fw_compile(
    fw_module: torch.fx.GraphModule,
    adjusted_flat_args: list[Any],
    maybe_subclass_meta: Optional[SubclassMeta],
    fw_metadata: ViewAndMutationMeta,
    num_fw_outs_saved_for_bw: int,
    aot_config: AOTConfig,
) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]:
    """
    Compile the forward graph. Returns:
    - the output strides of the forward graph
    - the compiled forward function

    Before compiling, we run pre_compile for the following wrappers:
    - FakifiedOutWrapper
    - FunctionalizedRngRuntimeWrapper
    After compiling, we run post_compile for the following wrappers:
    - EffectTokensWrapper
    - AOTDispatchSubclassWrapper
    - FunctionalizedRngRuntimeWrapper
    - FakifiedOutWrapper
    """
    with torch.no_grad():
        # AMP is already traced out in joint graph. we do not wish to reapply it accidentally
        # in the compiler.
        with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
            # flat_args at this point might still be subclasses-
            # make sure to pass the unwrapped fake tensors into the compiler!
            adjusted_flat_args = joint_inputs[0]

            fakified_out_wrapper = FakifiedOutWrapper()
            fakified_out_wrapper.pre_compile(
                fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
@@ -1679,7 +1778,7 @@ def aot_stage2_autograd(
                return_new_outs=False
            )

            if rng_states:
            if fw_metadata.num_graphsafe_rng_states > 0:
                index = fw_metadata.graphsafe_rng_state_index
                assert index is not None
                rng_states = [
@@ -1692,7 +1791,9 @@ def aot_stage2_autograd(
                fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
            )
            if tracing_context := torch._guards.TracingContext.try_get():
                tracing_context.fw_metadata = inner_meta
                tracing_context.fw_metadata = _get_inner_meta(
                    maybe_subclass_meta, fw_metadata
                )

            with TracingContext.report_output_strides() as fwd_output_strides:
                compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
@@ -1729,6 +1830,23 @@ def aot_stage2_autograd(
                runtime_metadata=fw_metadata,
            )

    return fwd_output_strides, compiled_fw_func


def _aot_stage2b_bw_compile(
    bw_module: torch.fx.GraphModule,
    maybe_subclass_meta: Optional[SubclassMeta],
    fw_metadata: ViewAndMutationMeta,
    fwd_output_strides: Optional[list[Optional[tuple[int, ...]]]],
    num_symints_saved_for_bw: int,
    aot_config: AOTConfig,
) -> tuple[list[object], Optional[Callable]]:
    """
    Compile the backward graph. Returns:
    - the placeholder list for the backward graph
    - the compiled backward function
    """
    with torch.no_grad():
        # NB: It's important to compile backwards ahead of time, as this may
        # add extra guards which we need to apply to the Dynamo cache at
        # forwards
@@ -1737,6 +1855,7 @@ def aot_stage2_autograd(

        forward_saved_for_backwards_strides = None
        if fwd_output_strides is not None:
            inner_meta = _get_inner_meta(maybe_subclass_meta, fw_metadata)
            forward_saved_for_backwards_strides = fwd_output_strides[
                inner_meta.tensors_saved_for_backwards_slice
            ]
@@ -1838,9 +1957,69 @@ def aot_stage2_autograd(

            _LazyGraphModule.force_recompile(bw_module)

    return placeholder_list, compiled_bw_func


def aot_stage2_autograd(
    aot_state: AOTState,
    aot_graph_capture: AOTGraphCapture,
) -> DispatchReturn:
    """
    Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers,
    and returns a wrapped torch.autograd.Function with a forward and backward.
    """

    fx_g = aot_graph_capture.graph_module
    maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta
    fw_metadata = aot_state.fw_metadata
    aot_config = aot_state.aot_config

    CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="autograd")
    joint_graph_str = _log_joint_graph(fx_g, aot_config)

    _apply_tensorify_python_scalars(fx_g)

    (
        fw_module,
        bw_module,
        num_fw_outs_saved_for_bw,
        num_symints_saved_for_bw,
        _indices_of_inps_to_detach,
        adjusted_flat_args,
    ) = _aot_stage2a_partition(
        fx_g,
        aot_graph_capture.updated_flat_args,
        maybe_subclass_meta,
        fw_metadata,
        aot_config,
    )

    fw_module_str, bw_module_str = _log_fw_bw_graphs(
        fw_module, bw_module, maybe_subclass_meta, fw_metadata, aot_config
    )

    fwd_output_strides, compiled_fw_func = _aot_stage2b_fw_compile(
        fw_module,
        adjusted_flat_args,
        maybe_subclass_meta,
        fw_metadata,
        num_fw_outs_saved_for_bw,
        aot_config,
    )

    placeholder_list, compiled_bw_func = _aot_stage2b_bw_compile(
        bw_module,
        maybe_subclass_meta,
        fw_metadata,
        fwd_output_strides,
        num_symints_saved_for_bw,
        aot_config,
    )

    saved_context = TracingContext.try_get()
    saved_compile_context = CompileContext.try_get()

    flat_args = aot_state.flat_args
    backward_state_indices = [
        idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
    ]
@@ -1858,6 +2037,7 @@ def aot_stage2_autograd(
    try_save_cache_entry: Optional[Callable] = None
    entry: Optional[GenericAOTAutogradCacheEntry] = None

    wrappers = aot_graph_capture.wrappers
    if aot_config.cache_info is not None:
        forward_time_taken_ns = time.time_ns() - aot_config.cache_info.start_time_ns
@@ -1926,6 +2106,7 @@ def aot_stage2_autograd(
            )
            try_save_cache_entry = None

    disable_amp = torch._C._is_any_autocast_enabled()
    compiled_fn = AOTDispatchAutograd.post_compile(
        compiled_fw_func,
        compiled_bw_func,