[inductor] Lift fw_compiler and bw_compiler as toplevel functions. (#161762)
This is a no-op refactor of compile_fx which lifts the logic of fw_compiler and bw_compiler to the top level, so that they can be reused in a different stack (e.g. precompile).

Differential Revision: D81292968 (https://our.internmc.facebook.com/intern/diff/D81292968/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161762
Approved by: https://github.com/angelayi, https://github.com/yushangdi
This commit is contained in:
parent 05eeb29976
commit eb78757708
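The practical payoff of the lift is that another stack can drive the forward compilation path without going through the fw_compiler_base closure inside compile_fx. A minimal sketch of such a caller, assuming the post-PR module layout; the precompile_forward wrapper and its inputs are hypothetical, not part of this PR:

# Hypothetical caller reusing the lifted helpers; gm and example_inputs
# are assumed to come from the caller's own AOTAutograd-style tracing.
from torch._inductor import config
from torch._inductor.compile_fx import (
    compile_fx_forward,
    create_compiler_config_extra,
    get_num_model_outputs,
)

def precompile_forward(gm, example_inputs):
    extra = create_compiler_config_extra(config)
    return compile_fx_forward(
        gm,
        example_inputs,
        num_orig_model_outputs=get_num_model_outputs(gm),
        num_example_inputs=len(example_inputs),
        compiler_config_extra=extra,
        is_inference=True,
    )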
@@ -154,6 +154,8 @@ else:
     from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
 
 if TYPE_CHECKING:
+    import types
+
     from torch._functorch._aot_autograd.schemas import (
         FQN,
         GraphInputName,
@@ -2120,6 +2122,248 @@ def partition_fn(
     )
 
 
+def get_num_model_outputs(model: GraphModule) -> int:
+    model_outputs_node = output_node(model)
+    model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+    return len(model_outputs)
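As a sanity check on what the new helper counts, here is a standalone sketch of the same flattening on a toy torch.fx graph. The diff's output_node helper is replaced by a manual scan, since this snippet only assumes public torch.fx and torch.utils._pytree:

import torch
import torch.fx
import torch.utils._pytree as pytree

class TwoOutputs(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2

gm = torch.fx.symbolic_trace(TwoOutputs())
# An fx graph has exactly one node with op == "output"; its args wrap
# everything the graph returns.
out_node = next(n for n in gm.graph.nodes if n.op == "output")
print(len(pytree.arg_tree_leaves(*out_node.args)))  # 2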
+
+
+@dataclass(frozen=True)
+class CompilerConfigExtra:
+    cudagraphs: BoxedBool
+    graph_id: int
+    forward_device: BoxedDeviceIndex
+
+
+def create_compiler_config_extra(config: types.ModuleType) -> CompilerConfigExtra:
+    # Although cudagraphs may have been enabled via config, various
+    # conditions (which are tested within the bowels of Inductor) may
+    # force cudagraphs to be disabled. This mutable box lets us retrieve
+    # the final determination of whether cudagraphs can actually be used.
+    cudagraphs = BoxedBool(config.triton.cudagraphs)
+
+    # TODO: The modern style is to use CompileId from TracingContext to
+    # identify Inductor compilation. However, this CompileId cannot
+    # uniquely identify multiple Inductor compilations that arise from
+    # DDPOptimizer.
+    graph_id = next(_graph_counter)
+
+    # See [Backward Generation Handling]
+    forward_device = BoxedDeviceIndex(None)
+
+    return CompilerConfigExtra(
+        cudagraphs=cudagraphs,
+        graph_id=graph_id,
+        forward_device=forward_device,
+    )
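The dataclass is frozen, yet two of its fields are deliberately mutable boxes: inner compilation stages can still flip the boxed values after the config object is constructed. A minimal sketch of the pattern with a stand-in box class (BoxedFlag is illustrative; the real BoxedBool/BoxedDeviceIndex live inside Inductor):

from dataclasses import dataclass

class BoxedFlag:
    # Stand-in for BoxedBool: one mutable slot behind a stable reference.
    def __init__(self, value: bool) -> None:
        self.value = value

@dataclass(frozen=True)
class Extra:
    cudagraphs: BoxedFlag

extra = Extra(cudagraphs=BoxedFlag(True))
# Reassigning extra.cudagraphs would raise FrozenInstanceError, but the
# box contents can still record the final cudagraphs determination:
extra.cudagraphs.value = False
print(extra.cudagraphs.value)  # False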
+
+
+def compile_fx_forward(
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
+    num_orig_model_outputs: int,
+    num_example_inputs: int,
+    compiler_config_extra: CompilerConfigExtra,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
+    is_inference: bool = False,
+) -> OutputCode:
+    """
+    Compile the forward graph of the given graph module.
+
+    Args:
+        gm: The graph module to compile.
+        example_inputs: The example inputs to use for compilation.
+        num_orig_model_outputs: The number of model outputs from the original dynamo graph.
+        num_example_inputs: The number of example inputs from the original dynamo graph.
+        compiler_config_extra: Extra configuration for the compiler.
+        inner_compile: The inner compile function to use.
+        is_inference: Whether this is an inference graph.
+    """
+
+    if is_inference:
+        # partition_fn won't be called
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "before_joint_graph",
+                "encoding": "string",
+            },
+            payload_fn=lambda: gm.print_readable(
+                print_output=False, include_stride=True, include_device=True
+            ),
+        )
+
+        _recursive_joint_graph_passes(gm)
+
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "after_joint_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: gm.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        ),
+    )
+
+    fixed = torch._inductor.utils.num_fw_fixed_arguments(
+        num_example_inputs, len(example_inputs)
+    )
+
+    model_outputs_node = output_node(gm)
+    if config.keep_output_stride:
+        model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+        num_model_outputs = len(model_outputs)
+
+        context = torch._guards.TracingContext.try_get()
+        # See Note [User Outputs in the inductor graph]
+        if context is not None and context.fw_metadata and not is_inference:
+            original_output_start_index = (
+                context.fw_metadata.num_mutated_inp_runtime_indices
+            )
+        else:
+            original_output_start_index = 0
+
+        assert num_orig_model_outputs <= num_model_outputs
+
+        # Note [User Outputs in the inductor graph]
+        # We make the following assumption:
+        # for inference,
+        #   len(orig_model_outputs) == len(model_outputs);
+        # for training,
+        #   len(orig_model_outputs) <= len(model_outputs).
+        # During training, model_outputs usually starts with the original
+        # module's outputs, followed by saved activations.
+        # But this may not hold if the model has in-place updated tensors:
+        # AOTAutograd will return those tensors before the original
+        # module's outputs.
+        # To be safe, we use the original_output_start_index field
+        # set by AOTAutograd to decide where the original module outputs start.
+        orig_output_end_idx = original_output_start_index + num_orig_model_outputs
+        # Sanity check: we are about to splice out the "user" outputs from the full set
+        # of "graph" outputs. Make sure we're within bounds.
+        assert orig_output_end_idx <= num_model_outputs
+
+        model_outputs_node.meta["user_visible_output_idxs"] = [
+            idx
+            for idx in range(original_output_start_index, orig_output_end_idx)
+            if isinstance(model_outputs[idx], torch.fx.Node)
+        ]
+    else:
+        model_outputs_node.meta["user_visible_output_idxs"] = []
+
+    # We also mark the invoke_subgraph outputs as user_visible to
+    # force the outputs of the invoke_subgraph subgraph to follow the
+    # original strides.
+    _recursive_record_user_visible_output_idxs(gm)
+
+    return inner_compile(
+        gm,
+        example_inputs,
+        static_input_idxs=get_static_input_idxs(fixed),
+        cudagraphs=compiler_config_extra.cudagraphs,
+        graph_id=compiler_config_extra.graph_id,
+        is_inference=is_inference,
+        boxed_forward_device_index=compiler_config_extra.forward_device,
+    )
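The index arithmetic in Note [User Outputs in the inductor graph] is easiest to see with toy numbers, e.g. a graph whose outputs are two mutated inputs, then two user outputs, then two saved activations (all values below are illustrative only):

num_model_outputs = 6            # everything the compiled graph returns
original_output_start_index = 2  # mutated inputs come first
num_orig_model_outputs = 2       # what the user's module returns

orig_output_end_idx = original_output_start_index + num_orig_model_outputs
assert orig_output_end_idx <= num_model_outputs
print(list(range(original_output_start_index, orig_output_end_idx)))
# [2, 3]: indices 0-1 are mutated inputs, 4-5 are saved activations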
+
+
+def compile_fx_backward(
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
+    compiler_config_extra: CompilerConfigExtra,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
+) -> OutputCode:
+    """
+    Compile the backward graph of the given graph module.
+
+    Args:
+        gm: The graph module to compile.
+        example_inputs: The example inputs to use for compilation.
+        compiler_config_extra: Extra configuration for the compiler.
+        inner_compile: The inner compile function to use.
+    """
+    from torch._dynamo.convert_frame import compile_lock
+
+    with compile_lock:
+        model_outputs_node = output_node(gm)
+        if config.bw_outputs_user_visible:
+            model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+            model_outputs_node.meta["user_visible_output_idxs"] = [
+                idx
+                for idx, n in enumerate(model_outputs)
+                if isinstance(n, torch.fx.Node)
+            ]
+        else:
+            model_outputs_node.meta["user_visible_output_idxs"] = []
+
+        fixed = count_tangents(gm)
+        with (
+            config.patch(get_cpp_wrapper_config())
+            if config.cpp_wrapper
+            else contextlib.nullcontext()
+        ):
+            return inner_compile(
+                gm,
+                example_inputs,
+                static_input_idxs=list(range(fixed)),
+                cudagraphs=compiler_config_extra.cudagraphs,
+                is_backward=True,
+                graph_id=compiler_config_extra.graph_id,
+                boxed_forward_device_index=compiler_config_extra.forward_device,
+            )
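The cpp_wrapper branch above uses a conditional expression as the context manager, so the config patch applies only when enabled while the body is written once. A self-contained sketch of that pattern (patched is a toy context manager standing in for Inductor's config.patch):

import contextlib

@contextlib.contextmanager
def patched(name):
    print(f"enter {name}")
    yield
    print(f"exit {name}")

cpp_wrapper = False
with patched("cpp_wrapper_config") if cpp_wrapper else contextlib.nullcontext():
    pass  # body runs either way; the patch only takes effect when enabled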
+
+
+def run_pre_grad_passes(
+    model_: GraphModule, example_inputs_: Sequence[InputType]
+) -> GraphModule:
+    # "before_pre_grad_graph" is used in the inductor provenance
+    # tracking highlighter front-end.
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "before_pre_grad_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: model_.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        )
+        + f"\n\n # graph id: {id(model_.graph)}",
+    )
+    pre_grad_graphs_log.debug(
+        "%s",
+        lazy_format_graph_code(
+            "BEFORE PRE GRAD",
+            model_,
+            include_stride=True,
+            include_device=True,
+            colored=True,
+        ),
+    )
+    torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
+
+    if config.trace.provenance_tracking_level == 1:
+        for node in model_.graph.nodes:
+            if node.stack_trace:
+                torch._inductor.debug._inductor_pre_grad_node_stack_trace[node.name] = (
+                    node.stack_trace
+                )
+
+    model_ = _recursive_pre_grad_passes(model_, example_inputs_)
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "after_pre_grad_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: model_.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        )
+        + f"\n\n # graph id: {id(model_.graph)}",
+    )
+    return model_
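Note that trace_structured takes metadata_fn/payload_fn callables rather than strings, so the potentially expensive print_readable only runs when the artifact sink is active. A toy illustration of that deferred-payload shape (log_artifact and expensive_render are made up for this sketch):

def log_artifact(sink_enabled: bool, payload_fn) -> None:
    if sink_enabled:
        print(payload_fn())

def expensive_render() -> str:
    print("rendering graph...")  # stands in for model_.print_readable(...)
    return "<graph text>"

log_artifact(False, expensive_render)  # render never runs
log_artifact(True, expensive_render)   # renders exactly once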
+
+
 def compile_fx(
     model_: GraphModule,
     example_inputs_: Sequence[InputType],
@@ -2264,50 +2508,7 @@ def compile_fx(
     # having AOTAutograd trace it.
     # TODO: Get rid of this?
     if isinstance(model_, GraphModule):
-        # "before_pre_grad_graph" is used in the inductor provenance
-        # tracking highlighter front-end.
-        trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "before_pre_grad_graph",
-                "encoding": "string",
-            },
-            payload_fn=lambda: model_.print_readable(
-                print_output=False, include_stride=True, include_device=True
-            )
-            + f"\n\n # graph id: {id(model_.graph)}",
-        )
-        pre_grad_graphs_log.debug(
-            "%s",
-            lazy_format_graph_code(
-                "BEFORE PRE GRAD",
-                model_,
-                include_stride=True,
-                include_device=True,
-                colored=True,
-            ),
-        )
-        torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
-
-        if config.trace.provenance_tracking_level == 1:
-            for node in model_.graph.nodes:
-                if node.stack_trace:
-                    torch._inductor.debug._inductor_pre_grad_node_stack_trace[
-                        node.name
-                    ] = node.stack_trace
-
-        model_ = _recursive_pre_grad_passes(model_, example_inputs_)
-        trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "after_pre_grad_graph",
-                "encoding": "string",
-            },
-            payload_fn=lambda: model_.print_readable(
-                print_output=False, include_stride=True, include_device=True
-            )
-            + f"\n\n # graph id: {id(model_.graph)}",
-        )
+        model_ = run_pre_grad_passes(model_, example_inputs_)
 
     # TODO: Move this before recursive pre-grad passes
     # NB: This short circuit never occurs for Dynamo produced graphs
@@ -2323,20 +2524,7 @@ def compile_fx(
 
     num_example_inputs = len(example_inputs_)
 
-    # Although cudagraphs may have been enabled via config, various
-    # conditions (which are tested within the bowels of Inductor) may
-    # force cudagraphs to be disabled. This mutable box lets us retrieve
-    # the final determination of whether cudagraphs can actually be used.
-    cudagraphs = BoxedBool(config.triton.cudagraphs)
-
-    # See [Backward Generation Handling]
-    forward_device = BoxedDeviceIndex(None)
-
-    # TODO: The modern style is to use CompileId from TracingContext to
-    # identify Inductor compilation. However, this CompileId cannot
-    # uniquely identify multiple Inductor compilations that arise from
-    # DDPOptimizer.
-    graph_id = next(_graph_counter)
+    compiler_config_extra = create_compiler_config_extra(config)
 
     decompositions = (
         decompositions if decompositions is not None else select_decomp_table()
@@ -2348,105 +2536,18 @@ def compile_fx(
         is_inference: bool,
     ) -> OutputCode:
         with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
-            if is_inference:
-                # partition_fn won't be called
-                trace_structured(
-                    "artifact",
-                    metadata_fn=lambda: {
-                        "name": "before_joint_graph",
-                        "encoding": "string",
-                    },
-                    payload_fn=lambda: gm.print_readable(
-                        print_output=False, include_stride=True, include_device=True
-                    ),
-                )
-
-                _recursive_joint_graph_passes(gm)
-
-            trace_structured(
-                "artifact",
-                metadata_fn=lambda: {
-                    "name": "after_joint_graph",
-                    "encoding": "string",
-                },
-                payload_fn=lambda: gm.print_readable(
-                    print_output=False, include_stride=True, include_device=True
-                ),
-            )
-
-            fixed = torch._inductor.utils.num_fw_fixed_arguments(
-                num_example_inputs, len(example_inputs)
-            )
-
-            model_outputs_node = output_node(gm)
-            if config.keep_output_stride:
-                model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
-                num_model_outputs = len(model_outputs)
-
-                context = torch._guards.TracingContext.try_get()
-                # See Note [User Outputs in the inductor graph]
-                if context is not None and context.fw_metadata and not is_inference:
-                    original_output_start_index = (
-                        context.fw_metadata.num_mutated_inp_runtime_indices
-                    )
-                else:
-                    original_output_start_index = 0
-
-                if isinstance(model_, GraphModule):
-                    *_, orig_model_outputs_node = model_.graph.nodes
-                    assert orig_model_outputs_node.op == "output"
-                    orig_model_outputs, _ = pytree.tree_flatten(
-                        orig_model_outputs_node.args
-                    )
-                    num_orig_model_outputs = len(orig_model_outputs)
-                else:
-                    num_orig_model_outputs = num_model_outputs
-
-                assert num_orig_model_outputs <= num_model_outputs
-
-                # Note [User Outputs in the inductor graph]
-                # We make the following assumption:
-                # for inference,
-                #   len(orig_model_outputs) == len(model_outputs);
-                # for training,
-                #   len(orig_model_outputs) <= len(model_outputs).
-                # During training, model_outputs usually starts with the original
-                # module's outputs, followed by saved activations.
-                # But this may not hold if the model has in-place updated tensors:
-                # AOTAutograd will return those tensors before the original
-                # module's outputs.
-                # To be safe, we use the original_output_start_index field
-                # set by AOTAutograd to decide where the original module outputs start.
-                orig_output_end_idx = (
-                    original_output_start_index + num_orig_model_outputs
-                )
-                # Sanity check: we are about to splice out the "user" outputs from the full set
-                # of "graph" outputs. Make sure we're within bounds.
-                assert orig_output_end_idx <= num_model_outputs
-
-                model_outputs_node.meta["user_visible_output_idxs"] = [
-                    idx
-                    for idx in range(
-                        original_output_start_index, orig_output_end_idx
-                    )
-                    if isinstance(model_outputs[idx], torch.fx.Node)
-                ]
-            else:
-                model_outputs_node.meta["user_visible_output_idxs"] = []
-
-            # We also mark the invoke_subgraph outputs as user_visible to
-            # force the outputs of the invoke_subgraph subgraph to follow the
-            # original strides.
-            _recursive_record_user_visible_output_idxs(gm)
-
-            return inner_compile(
-                gm,
-                example_inputs,
-                static_input_idxs=get_static_input_idxs(fixed),
-                cudagraphs=cudagraphs,
-                graph_id=graph_id,
-                is_inference=is_inference,
-                boxed_forward_device_index=forward_device,
-            )
+            if isinstance(model_, GraphModule):
+                num_orig_model_outputs = get_num_model_outputs(model_)
+            else:
+                num_orig_model_outputs = get_num_model_outputs(gm)
+            return compile_fx_forward(
+                gm,
+                example_inputs,
+                num_orig_model_outputs=num_orig_model_outputs,
+                num_example_inputs=num_example_inputs,
+                compiler_config_extra=compiler_config_extra,
+                inner_compile=inner_compile,
+                is_inference=is_inference,
+            )
 
     fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = (
@@ -2460,9 +2561,9 @@ def compile_fx(
             dynamo_model=model_,
             num_example_inputs=num_example_inputs,
             inner_compile=inner_compile,
-            cudagraphs=cudagraphs,
-            graph_id=graph_id,
-            forward_device=forward_device,
+            cudagraphs=compiler_config_extra.cudagraphs,
+            graph_id=compiler_config_extra.graph_id,
+            forward_device=compiler_config_extra.forward_device,
         )
     else:
         inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
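The else branch pins is_inference=True on fw_compiler_base via functools.partial, so a single closure serves both the training and inference compilers. A toy reduction of that pattern (the fw_compiler_base stand-in below is not the PR's code):

import functools

def fw_compiler_base(gm, example_inputs, is_inference):
    mode = "inference" if is_inference else "training"
    return f"compiled {gm} ({mode})"

inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
print(inference_compiler("gm", []))  # compiled gm (inference)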
@@ -2474,38 +2575,15 @@ def compile_fx(
     def bw_compiler(
         gm: GraphModule, example_inputs: Sequence[InputType]
     ) -> OutputCode:
-        from torch._dynamo.convert_frame import compile_lock
-
-        with (
-            dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"),
-            compile_lock,
-        ):
-            model_outputs_node = output_node(gm)
-            if config.bw_outputs_user_visible:
-                model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
-                model_outputs_node.meta["user_visible_output_idxs"] = [
-                    idx
-                    for idx, n in enumerate(model_outputs)
-                    if isinstance(n, torch.fx.Node)
-                ]
-            else:
-                model_outputs_node.meta["user_visible_output_idxs"] = []
-
-            fixed = count_tangents(gm)
-            with (
-                config.patch(get_cpp_wrapper_config())
-                if config.cpp_wrapper
-                else contextlib.nullcontext()
-            ):
-                return inner_compile(
-                    gm,
-                    example_inputs,
-                    static_input_idxs=list(range(fixed)),
-                    cudagraphs=cudagraphs,
-                    is_backward=True,
-                    graph_id=graph_id,
-                    boxed_forward_device_index=forward_device,
-                )
+        with dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"):
+            return compile_fx_backward(
+                gm,
+                example_inputs,
+                compiler_config_extra=compiler_config_extra,
+                inner_compile=inner_compile,
+            )
 
     bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler)
@@ -2592,8 +2670,8 @@ def compile_fx(
                     decompositions=decompositions,
                     partition_fn=partition_fn,
                     keep_inference_input_mutations=True,
-                    cudagraphs=cudagraphs,
-                    boxed_forward_device_index=forward_device,
+                    cudagraphs=compiler_config_extra.cudagraphs,
+                    boxed_forward_device_index=compiler_config_extra.forward_device,
                     ignore_shape_env=ignore_shape_env,
                 )(model_, example_inputs_)
             except ShortenTraceback as e: