diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 707c640b908..0489bc1ba86 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -154,6 +154,8 @@ else:
     from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log

 if TYPE_CHECKING:
+    import types
+
     from torch._functorch._aot_autograd.schemas import (
         FQN,
         GraphInputName,
@@ -2120,6 +2122,248 @@ def partition_fn(
     )


+def get_num_model_outputs(model: GraphModule) -> int:
+    model_outputs_node = output_node(model)
+    model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+    return len(model_outputs)
+
+
+@dataclass(frozen=True)
+class CompilerConfigExtra:
+    cudagraphs: BoxedBool
+    graph_id: int
+    forward_device: BoxedDeviceIndex
+
+
+def create_compiler_config_extra(config: types.ModuleType) -> CompilerConfigExtra:
+    # Although cudagraphs may have been enabled via config, various
+    # conditions (which are tested within the bowels of Inductor) may
+    # force cudagraphs to be disabled.  This mutable box lets us retrieve
+    # the final determination if cudagraphs actually can be used or not.
+    cudagraphs = BoxedBool(config.triton.cudagraphs)
+
+    # TODO: The modern style is to use CompileId from TracingContext to
+    # identify Inductor compilation.  However, this CompileId cannot
+    # uniquely identify multiple Inductor compilations that arise from
+    # DDPOptimizer
+    graph_id = next(_graph_counter)
+
+    # See [Backward Generation Handling]
+    forward_device = BoxedDeviceIndex(None)
+
+    return CompilerConfigExtra(
+        cudagraphs=cudagraphs,
+        graph_id=graph_id,
+        forward_device=forward_device,
+    )
+
+
+def compile_fx_forward(
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
+    num_orig_model_outputs: int,
+    num_example_inputs: int,
+    compiler_config_extra: CompilerConfigExtra,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
+    is_inference: bool = False,
+) -> OutputCode:
+    """
+    Compile the forward graph of the given graph module.
+
+    Args:
+        gm: The graph module to compile.
+        example_inputs: The example inputs to use for compilation.
+        num_orig_model_outputs: The number of model outputs from the original dynamo graph.
+        num_example_inputs: The number of example inputs from the original dynamo graph.
+        compiler_config_extra: Extra configuration for the compiler.
+        inner_compile: The inner compile function to use.
+        is_inference: Whether this is an inference graph.
+    """
+
+    if is_inference:
+        # partition_fn won't be called
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "before_joint_graph",
+                "encoding": "string",
+            },
+            payload_fn=lambda: gm.print_readable(
+                print_output=False, include_stride=True, include_device=True
+            ),
+        )
+
+        _recursive_joint_graph_passes(gm)
+
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "after_joint_graph",
+                "encoding": "string",
+            },
+            payload_fn=lambda: gm.print_readable(
+                print_output=False, include_stride=True, include_device=True
+            ),
+        )
+
+    fixed = torch._inductor.utils.num_fw_fixed_arguments(
+        num_example_inputs, len(example_inputs)
+    )
+
+    model_outputs_node = output_node(gm)
+    if config.keep_output_stride:
+        model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+        num_model_outputs = len(model_outputs)
+
+        context = torch._guards.TracingContext.try_get()
+        # See Note [User Outputs in the inductor graph]
+        if context is not None and context.fw_metadata and not is_inference:
+            original_output_start_index = (
+                context.fw_metadata.num_mutated_inp_runtime_indices
+            )
+        else:
+            original_output_start_index = 0
+
+        assert num_orig_model_outputs <= num_model_outputs
+
+        # Note [User Outputs in the inductor graph]
+        # We make the following assumptions:
+        # For inference
+        #   len(orig_model_outputs) == len(model_outputs)
+        # For training
+        #   len(orig_model_outputs) <= len(model_outputs)
+        # During training, model_outputs usually starts with the original
+        # module's outputs, followed by saved activations.  But this may not
+        # hold if the model has in-place updated tensors: AOTAutograd will
+        # return those tensors before the original module's outputs.
+        # To be safe, we use the original_output_start_index field set by
+        # AOTAutograd to decide where the original module outputs start.
+        orig_output_end_idx = original_output_start_index + num_orig_model_outputs
+        # Sanity check: we are about to splice out the "user" outputs from the full set
+        # of "graph" outputs.  Make sure we're within bounds.
+        assert orig_output_end_idx <= num_model_outputs
+
+        model_outputs_node.meta["user_visible_output_idxs"] = [
+            idx
+            for idx in range(original_output_start_index, orig_output_end_idx)
+            if isinstance(model_outputs[idx], torch.fx.Node)
+        ]
+    else:
+        model_outputs_node.meta["user_visible_output_idxs"] = []
+
+    # We also mark the invoke_subgraph outputs as user_visible to
+    # force the outputs of the invoke_subgraph subgraphs to follow
+    # the original strides
+    _recursive_record_user_visible_output_idxs(gm)
+
+    return inner_compile(
+        gm,
+        example_inputs,
+        static_input_idxs=get_static_input_idxs(fixed),
+        cudagraphs=compiler_config_extra.cudagraphs,
+        graph_id=compiler_config_extra.graph_id,
+        is_inference=is_inference,
+        boxed_forward_device_index=compiler_config_extra.forward_device,
+    )
+
+
+def compile_fx_backward(
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
+    compiler_config_extra: CompilerConfigExtra,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
+) -> OutputCode:
+    """
+    Compile the backward graph of the given graph module.
+
+    Args:
+        gm: The graph module to compile.
+        example_inputs: The example inputs to use for compilation.
+        compiler_config_extra: Extra configuration for the compiler.
+        inner_compile: The inner compile function to use.
+    """
+    from torch._dynamo.convert_frame import compile_lock
+
+    with compile_lock:
+        model_outputs_node = output_node(gm)
+        if config.bw_outputs_user_visible:
+            model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+            model_outputs_node.meta["user_visible_output_idxs"] = [
+                idx
+                for idx, n in enumerate(model_outputs)
+                if isinstance(n, torch.fx.Node)
+            ]
+        else:
+            model_outputs_node.meta["user_visible_output_idxs"] = []
+
+        fixed = count_tangents(gm)
+        with (
+            config.patch(get_cpp_wrapper_config())
+            if config.cpp_wrapper
+            else contextlib.nullcontext()
+        ):
+            return inner_compile(
+                gm,
+                example_inputs,
+                static_input_idxs=list(range(fixed)),
+                cudagraphs=compiler_config_extra.cudagraphs,
+                is_backward=True,
+                graph_id=compiler_config_extra.graph_id,
+                boxed_forward_device_index=compiler_config_extra.forward_device,
+            )
+
+
+def run_pre_grad_passes(
+    model_: GraphModule, example_inputs_: Sequence[InputType]
+) -> GraphModule:
+    # "before_pre_grad_graph" is used in inductor provenance
+    # tracking highlighter front-end.
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "before_pre_grad_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: model_.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        )
+        + f"\n\n # graph id: {id(model_.graph)}",
+    )
+    pre_grad_graphs_log.debug(
+        "%s",
+        lazy_format_graph_code(
+            "BEFORE PRE GRAD",
+            model_,
+            include_stride=True,
+            include_device=True,
+            colored=True,
+        ),
+    )
+    torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
+
+    if config.trace.provenance_tracking_level == 1:
+        for node in model_.graph.nodes:
+            if node.stack_trace:
+                torch._inductor.debug._inductor_pre_grad_node_stack_trace[node.name] = (
+                    node.stack_trace
+                )
+
+    model_ = _recursive_pre_grad_passes(model_, example_inputs_)
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "after_pre_grad_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: model_.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        )
+        + f"\n\n # graph id: {id(model_.graph)}",
+    )
+    return model_
+
+
 def compile_fx(
     model_: GraphModule,
     example_inputs_: Sequence[InputType],
@@ -2264,50 +2508,7 @@ def compile_fx(
         # having AOTAutograd trace it.
         # TODO: Get rid of this?
         if isinstance(model_, GraphModule):
-            # "before_pre_grad_graph" is used in inductor provenance
-            # tracking highlighter front-end.
-            trace_structured(
-                "artifact",
-                metadata_fn=lambda: {
-                    "name": "before_pre_grad_graph",
-                    "encoding": "string",
-                },
-                payload_fn=lambda: model_.print_readable(
-                    print_output=False, include_stride=True, include_device=True
-                )
-                + f"\n\n # graph id: {id(model_.graph)}",
-            )
-            pre_grad_graphs_log.debug(
-                "%s",
-                lazy_format_graph_code(
-                    "BEFORE PRE GRAD",
-                    model_,
-                    include_stride=True,
-                    include_device=True,
-                    colored=True,
-                ),
-            )
-            torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
-
-            if config.trace.provenance_tracking_level == 1:
-                for node in model_.graph.nodes:
-                    if node.stack_trace:
-                        torch._inductor.debug._inductor_pre_grad_node_stack_trace[
-                            node.name
-                        ] = node.stack_trace
-
-            model_ = _recursive_pre_grad_passes(model_, example_inputs_)
-            trace_structured(
-                "artifact",
-                metadata_fn=lambda: {
-                    "name": "after_pre_grad_graph",
-                    "encoding": "string",
-                },
-                payload_fn=lambda: model_.print_readable(
-                    print_output=False, include_stride=True, include_device=True
-                )
-                + f"\n\n # graph id: {id(model_.graph)}",
-            )
+            model_ = run_pre_grad_passes(model_, example_inputs_)

         # TODO: Move this before recursive pre-grad passes
         # NB: This short circuit never occurs for Dynamo produced graphs
@@ -2323,20 +2524,7 @@ def compile_fx(
         num_example_inputs = len(example_inputs_)

-        # Although cudagraphs may have been enabled via config, various
-        # conditions (which are tested within the bowels of Inductor) may
-        # force cudagraphs to be disabled.  This mutable box lets us retrieve
-        # the final determination if cudagraphs actually can be used or not.
-        cudagraphs = BoxedBool(config.triton.cudagraphs)
-
-        # See [Backward Generation Handling]
-        forward_device = BoxedDeviceIndex(None)
-
-        # TODO: The modern style is to use CompileId from TracingContext to
-        # identify Inductor compilation.  However, this CompileId cannot
-        # uniquely identify multiple Inductor compilations that arise from
-        # DDPOptimizer
-        graph_id = next(_graph_counter)
+        compiler_config_extra = create_compiler_config_extra(config)

         decompositions = (
             decompositions if decompositions is not None else select_decomp_table()
         )
@@ -2348,105 +2536,18 @@ def compile_fx(
             is_inference: bool,
         ) -> OutputCode:
             with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
-                if is_inference:
-                    # partition_fn won't be called
-                    trace_structured(
-                        "artifact",
-                        metadata_fn=lambda: {
-                            "name": "before_joint_graph",
-                            "encoding": "string",
-                        },
-                        payload_fn=lambda: gm.print_readable(
-                            print_output=False, include_stride=True, include_device=True
-                        ),
-                    )
-
-                    _recursive_joint_graph_passes(gm)
-
-                    trace_structured(
-                        "artifact",
-                        metadata_fn=lambda: {
-                            "name": "after_joint_graph",
-                            "encoding": "string",
-                        },
-                        payload_fn=lambda: gm.print_readable(
-                            print_output=False, include_stride=True, include_device=True
-                        ),
-                    )
-
-                fixed = torch._inductor.utils.num_fw_fixed_arguments(
-                    num_example_inputs, len(example_inputs)
-                )
-
-                model_outputs_node = output_node(gm)
-                if config.keep_output_stride:
-                    model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
-                    num_model_outputs = len(model_outputs)
-
-                    context = torch._guards.TracingContext.try_get()
-                    # See Note [User Outputs in the inductor graph]
-                    if context is not None and context.fw_metadata and not is_inference:
-                        original_output_start_index = (
-                            context.fw_metadata.num_mutated_inp_runtime_indices
-                        )
-                    else:
-                        original_output_start_index = 0
-
-                    if isinstance(model_, GraphModule):
-                        *_, orig_model_outputs_node = model_.graph.nodes
-                        assert orig_model_outputs_node.op == "output"
-                        orig_model_outputs, _ = pytree.tree_flatten(
-                            orig_model_outputs_node.args
-                        )
-                        num_orig_model_outputs = len(orig_model_outputs)
-                    else:
-                        num_orig_model_outputs = num_model_outputs
-
-                    assert num_orig_model_outputs <= num_model_outputs
-
-                    # Note [User Outputs in the inductor graph]
-                    # We makes the following assumption
-                    # For inference
-                    #   len(orig_model_outputs) == len(model_outputs)
-                    # For training
-                    #   len(orig_model_outputs) <= len(model_outputs)
-                    # During training, most of the time the model_outputs starts with
-                    # original module's outputs followed by saved activations.
-                    # But this can be not true if the model have inplace updated tensors.
-                    # AOTAutograd will make those tensors being returned before the original
-                    # module's output.
-                    # To make things safe, we'll use original_output_start_index field
-                    # set by AOTAutograd to decide where the original module outputs start.
-                    orig_output_end_idx = (
-                        original_output_start_index + num_orig_model_outputs
-                    )
-                    # Sanity check: we are about to splice out the "user" outputs from the full set
-                    # of "graph" outputs.  Make sure we're within bounds.
-                    assert orig_output_end_idx <= num_model_outputs
-
-                    model_outputs_node.meta["user_visible_output_idxs"] = [
-                        idx
-                        for idx in range(
-                            original_output_start_index, orig_output_end_idx
-                        )
-                        if isinstance(model_outputs[idx], torch.fx.Node)
-                    ]
+                if isinstance(model_, GraphModule):
+                    num_orig_model_outputs = get_num_model_outputs(model_)
                 else:
-                    model_outputs_node.meta["user_visible_output_idxs"] = []
-
-                # We also mark the invoke_subgraph outputs as user_visible to
-                # force the outputs of invoke_subgraph subgraph to follow the
-                # original strides
-                _recursive_record_user_visible_output_idxs(gm)
-
-                return inner_compile(
+                    num_orig_model_outputs = get_num_model_outputs(gm)
+                return compile_fx_forward(
                     gm,
                     example_inputs,
-                    static_input_idxs=get_static_input_idxs(fixed),
-                    cudagraphs=cudagraphs,
-                    graph_id=graph_id,
+                    num_orig_model_outputs=num_orig_model_outputs,
+                    num_example_inputs=num_example_inputs,
+                    compiler_config_extra=compiler_config_extra,
+                    inner_compile=inner_compile,
                     is_inference=is_inference,
-                    boxed_forward_device_index=forward_device,
                 )

         fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = (
@@ -2460,9 +2561,9 @@ def compile_fx(
                 dynamo_model=model_,
                 num_example_inputs=num_example_inputs,
                 inner_compile=inner_compile,
-                cudagraphs=cudagraphs,
-                graph_id=graph_id,
-                forward_device=forward_device,
+                cudagraphs=compiler_config_extra.cudagraphs,
+                graph_id=compiler_config_extra.graph_id,
+                forward_device=compiler_config_extra.forward_device,
             )
         else:
             inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
@@ -2474,38 +2575,15 @@ def compile_fx(
         def bw_compiler(
             gm: GraphModule, example_inputs: Sequence[InputType]
         ) -> OutputCode:
-            from torch._dynamo.convert_frame import compile_lock
-
             with (
                 dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"),
-                compile_lock,
             ):
-                model_outputs_node = output_node(gm)
-                if config.bw_outputs_user_visible:
-                    model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
-                    model_outputs_node.meta["user_visible_output_idxs"] = [
-                        idx
-                        for idx, n in enumerate(model_outputs)
-                        if isinstance(n, torch.fx.Node)
-                    ]
-                else:
-                    model_outputs_node.meta["user_visible_output_idxs"] = []
-
-                fixed = count_tangents(gm)
-                with (
-                    config.patch(get_cpp_wrapper_config())
-                    if config.cpp_wrapper
-                    else contextlib.nullcontext()
-                ):
-                    return inner_compile(
-                        gm,
-                        example_inputs,
-                        static_input_idxs=list(range(fixed)),
-                        cudagraphs=cudagraphs,
-                        is_backward=True,
-                        graph_id=graph_id,
-                        boxed_forward_device_index=forward_device,
-                    )
+                return compile_fx_backward(
+                    gm,
+                    example_inputs,
+                    compiler_config_extra=compiler_config_extra,
+                    inner_compile=inner_compile,
+                )

         bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler)

@@ -2592,8 +2670,8 @@ def compile_fx(
                     decompositions=decompositions,
                     partition_fn=partition_fn,
                     keep_inference_input_mutations=True,
-                    cudagraphs=cudagraphs,
-                    boxed_forward_device_index=forward_device,
+                    cudagraphs=compiler_config_extra.cudagraphs,
+                    boxed_forward_device_index=compiler_config_extra.forward_device,
                     ignore_shape_env=ignore_shape_env,
                 )(model_, example_inputs_)
             except ShortenTraceback as e:
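
Illustrative sketch (not part of the patch): a simplified view of how the helpers extracted above are meant to compose. The imported names are the ones introduced by this diff; make_compilers itself is hypothetical, and the real compile_fx additionally handles freezing, cpp_wrapper, and the AOT compilation paths.

# Illustrative only: simplified wiring of the helpers added in this patch.
# `make_compilers` is hypothetical; the real wiring lives in compile_fx.
import functools

from torch._inductor import config
from torch._inductor.compile_fx import (
    compile_fx_backward,
    compile_fx_forward,
    create_compiler_config_extra,
    get_num_model_outputs,
)


def make_compilers(dynamo_gm, num_example_inputs):
    # One bundle of per-compilation state (the cudagraphs box, graph_id, and
    # forward-device box) shared by forward, backward, and inference compilers.
    extra = create_compiler_config_extra(config)

    def fw_compiler(gm, example_inputs, *, is_inference=False):
        return compile_fx_forward(
            gm,
            example_inputs,
            num_orig_model_outputs=get_num_model_outputs(dynamo_gm),
            num_example_inputs=num_example_inputs,
            compiler_config_extra=extra,
            is_inference=is_inference,
        )

    def bw_compiler(gm, example_inputs):
        return compile_fx_backward(gm, example_inputs, compiler_config_extra=extra)

    inference_compiler = functools.partial(fw_compiler, is_inference=True)
    return fw_compiler, bw_compiler, inference_compiler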