[inductor] Lift fw_compiler and bw_compiler as toplevel functions. (#161762)
This is a no-op refactor of compile_fx that lifts the logic of fw_compiler and bw_compiler into toplevel functions, so that they can be reused in a different stack (e.g. precompile).

Differential Revision: [D81292968](https://our.internmc.facebook.com/intern/diff/D81292968/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161762
Approved by: https://github.com/angelayi, https://github.com/yushangdi
This commit is contained in:
parent 05eeb29976
commit eb78757708
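As a rough illustration of the reuse this refactor enables, the sketch below calls the newly lifted helpers directly instead of going through compile_fx. It is not part of this PR: the import path and the driver function are assumptions, and the GraphModules and example inputs would come from whatever stack (e.g. precompile) is driving compilation.

```python
# Hypothetical sketch (not in this PR): reusing the lifted helpers from another
# stack. The import path is assumed to be the module this diff touches;
# `fw_gm`/`bw_gm` are already-partitioned forward/backward GraphModules and
# `dynamo_gm` is the original dynamo-captured graph, all supplied by the caller.
from torch._inductor import config
from torch._inductor.compile_fx import (
    compile_fx_backward,
    compile_fx_forward,
    create_compiler_config_extra,
    get_num_model_outputs,
)


def compile_partitioned_graphs(dynamo_gm, fw_gm, bw_gm, fw_inputs, bw_inputs, num_example_inputs):
    # One shared bundle of the cudagraphs box, graph id, and forward-device box,
    # threaded through both compilations just as compile_fx does internally.
    extra = create_compiler_config_extra(config)

    fw_code = compile_fx_forward(
        fw_gm,
        fw_inputs,
        num_orig_model_outputs=get_num_model_outputs(dynamo_gm),
        num_example_inputs=num_example_inputs,  # count of the original dynamo example inputs
        compiler_config_extra=extra,
    )
    bw_code = compile_fx_backward(bw_gm, bw_inputs, compiler_config_extra=extra)
    return fw_code, bw_code
```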
@@ -154,6 +154,8 @@ else:
     from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log

 if TYPE_CHECKING:
+    import types
+
     from torch._functorch._aot_autograd.schemas import (
         FQN,
         GraphInputName,
@@ -2120,6 +2122,248 @@ def partition_fn(
     )


+def get_num_model_outputs(model: GraphModule) -> int:
+    model_outputs_node = output_node(model)
+    model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+    return len(model_outputs)
+
+
+@dataclass(frozen=True)
+class CompilerConfigExtra:
+    cudagraphs: BoxedBool
+    graph_id: int
+    forward_device: BoxedDeviceIndex
+
+
+def create_compiler_config_extra(config: types.ModuleType) -> CompilerConfigExtra:
+    # Although cudagraphs may have been enabled via config, various
+    # conditions (which are tested within the bowels of Inductor) may
+    # force cudagraphs to be disabled. This mutable box lets us retrieve
+    # the final determination if cudagraphs actually can be used or not.
+    cudagraphs = BoxedBool(config.triton.cudagraphs)
+
+    # TODO: The modern style is to use CompileId from TracingContext to
+    # identify Inductor compilation. However, this CompileId cannot
+    # uniquely identify multiple Inductor compilations that arise from
+    # DDPOptimizer
+    graph_id = next(_graph_counter)
+
+    # See [Backward Generation Handling]
+    forward_device = BoxedDeviceIndex(None)
+
+    return CompilerConfigExtra(
+        cudagraphs=cudagraphs,
+        graph_id=graph_id,
+        forward_device=forward_device,
+    )
+
+
+def compile_fx_forward(
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
+    num_orig_model_outputs: int,
+    num_example_inputs: int,
+    compiler_config_extra: CompilerConfigExtra,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
+    is_inference: bool = False,
+) -> OutputCode:
+    """
+    Compile the forward graph of the given graph module.
+
+    Args:
+        gm: The graph module to compile.
+        example_inputs: The example inputs to use for compilation.
+        num_orig_model_outputs: The number of model outputs from the original dynamo graph.
+        num_example_inputs: The number of example inputs from the original dynamo graph.
+        compiler_config_extra: Extra configuration for the compiler.
+        inner_compile: The inner compile function to use.
+        is_inference: Whether this is an inference graph.
+    """
+
+    if is_inference:
+        # partition_fn won't be called
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "before_joint_graph",
+                "encoding": "string",
+            },
+            payload_fn=lambda: gm.print_readable(
+                print_output=False, include_stride=True, include_device=True
+            ),
+        )
+
+        _recursive_joint_graph_passes(gm)
+
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "after_joint_graph",
+                "encoding": "string",
+            },
+            payload_fn=lambda: gm.print_readable(
+                print_output=False, include_stride=True, include_device=True
+            ),
+        )
+
+    fixed = torch._inductor.utils.num_fw_fixed_arguments(
+        num_example_inputs, len(example_inputs)
+    )
+
+    model_outputs_node = output_node(gm)
+    if config.keep_output_stride:
+        model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+        num_model_outputs = len(model_outputs)
+
+        context = torch._guards.TracingContext.try_get()
+        # See Note [User Outputs in the inductor graph]
+        if context is not None and context.fw_metadata and not is_inference:
+            original_output_start_index = (
+                context.fw_metadata.num_mutated_inp_runtime_indices
+            )
+        else:
+            original_output_start_index = 0
+
+        assert num_orig_model_outputs <= num_model_outputs
+
+        # Note [User Outputs in the inductor graph]
+        # We makes the following assumption
+        # For inference
+        #   len(orig_model_outputs) == len(model_outputs)
+        # For training
+        #   len(orig_model_outputs) <= len(model_outputs)
+        # During training, most of the time the model_outputs starts with
+        # original module's outputs followed by saved activations.
+        # But this can be not true if the model have inplace updated tensors.
+        # AOTAutograd will make those tensors being returned before the original
+        # module's output.
+        # To make things safe, we'll use original_output_start_index field
+        # set by AOTAutograd to decide where the original module outputs start.
+        orig_output_end_idx = original_output_start_index + num_orig_model_outputs
+        # Sanity check: we are about to splice out the "user" outputs from the full set
+        # of "graph" outputs. Make sure we're within bounds.
+        assert orig_output_end_idx <= num_model_outputs
+
+        model_outputs_node.meta["user_visible_output_idxs"] = [
+            idx
+            for idx in range(original_output_start_index, orig_output_end_idx)
+            if isinstance(model_outputs[idx], torch.fx.Node)
+        ]
+    else:
+        model_outputs_node.meta["user_visible_output_idxs"] = []

+    # We also mark the invoke_subgraph outputs as user_visible to
+    # force the outputs of invoke_subgraph subgraph to follow the
+    # original strides
+    _recursive_record_user_visible_output_idxs(gm)
+
+    return inner_compile(
+        gm,
+        example_inputs,
+        static_input_idxs=get_static_input_idxs(fixed),
+        cudagraphs=compiler_config_extra.cudagraphs,
+        graph_id=compiler_config_extra.graph_id,
+        is_inference=is_inference,
+        boxed_forward_device_index=compiler_config_extra.forward_device,
+    )
+
+
+def compile_fx_backward(
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
+    compiler_config_extra: CompilerConfigExtra,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
+) -> OutputCode:
+    """
+    Compile the backward graph of the given graph module.
+
+    Args:
+        gm: The graph module to compile.
+        example_inputs: The example inputs to use for compilation.
+        compiler_config_extra: Extra configuration for the compiler.
+        inner_compile: The inner compile function to use.
+    """
+    from torch._dynamo.convert_frame import compile_lock
+
+    with compile_lock:
+        model_outputs_node = output_node(gm)
+        if config.bw_outputs_user_visible:
+            model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
+            model_outputs_node.meta["user_visible_output_idxs"] = [
+                idx
+                for idx, n in enumerate(model_outputs)
+                if isinstance(n, torch.fx.Node)
+            ]
+        else:
+            model_outputs_node.meta["user_visible_output_idxs"] = []
+
+        fixed = count_tangents(gm)
+        with (
+            config.patch(get_cpp_wrapper_config())
+            if config.cpp_wrapper
+            else contextlib.nullcontext()
+        ):
+            return inner_compile(
+                gm,
+                example_inputs,
+                static_input_idxs=list(range(fixed)),
+                cudagraphs=compiler_config_extra.cudagraphs,
+                is_backward=True,
+                graph_id=compiler_config_extra.graph_id,
+                boxed_forward_device_index=compiler_config_extra.forward_device,
+            )
+
+
+def run_pre_grad_passes(
+    model_: GraphModule, example_inputs_: Sequence[InputType]
+) -> GraphModule:
+    # "before_pre_grad_graph" is used in inductor provenance
+    # tracking highlighter front-end.
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "before_pre_grad_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: model_.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        )
+        + f"\n\n # graph id: {id(model_.graph)}",
+    )
+    pre_grad_graphs_log.debug(
+        "%s",
+        lazy_format_graph_code(
+            "BEFORE PRE GRAD",
+            model_,
+            include_stride=True,
+            include_device=True,
+            colored=True,
+        ),
+    )
+    torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
+
+    if config.trace.provenance_tracking_level == 1:
+        for node in model_.graph.nodes:
+            if node.stack_trace:
+                torch._inductor.debug._inductor_pre_grad_node_stack_trace[node.name] = (
+                    node.stack_trace
+                )
+
+    model_ = _recursive_pre_grad_passes(model_, example_inputs_)
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "after_pre_grad_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: model_.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        )
+        + f"\n\n # graph id: {id(model_.graph)}",
+    )
+    return model_
+
+
 def compile_fx(
     model_: GraphModule,
     example_inputs_: Sequence[InputType],
@@ -2264,50 +2508,7 @@ def compile_fx(
     # having AOTAutograd trace it.
     # TODO: Get rid of this?
     if isinstance(model_, GraphModule):
-        # "before_pre_grad_graph" is used in inductor provenance
-        # tracking highlighter front-end.
-        trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "before_pre_grad_graph",
-                "encoding": "string",
-            },
-            payload_fn=lambda: model_.print_readable(
-                print_output=False, include_stride=True, include_device=True
-            )
-            + f"\n\n # graph id: {id(model_.graph)}",
-        )
-        pre_grad_graphs_log.debug(
-            "%s",
-            lazy_format_graph_code(
-                "BEFORE PRE GRAD",
-                model_,
-                include_stride=True,
-                include_device=True,
-                colored=True,
-            ),
-        )
-        torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
-
-        if config.trace.provenance_tracking_level == 1:
-            for node in model_.graph.nodes:
-                if node.stack_trace:
-                    torch._inductor.debug._inductor_pre_grad_node_stack_trace[
-                        node.name
-                    ] = node.stack_trace
-
-        model_ = _recursive_pre_grad_passes(model_, example_inputs_)
-        trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "after_pre_grad_graph",
-                "encoding": "string",
-            },
-            payload_fn=lambda: model_.print_readable(
-                print_output=False, include_stride=True, include_device=True
-            )
-            + f"\n\n # graph id: {id(model_.graph)}",
-        )
+        model_ = run_pre_grad_passes(model_, example_inputs_)

     # TODO: Move this before recursive pre-grad passes
     # NB: This short circuit never occurs for Dynamo produced graphs
@@ -2323,20 +2524,7 @@ def compile_fx(

     num_example_inputs = len(example_inputs_)

-    # Although cudagraphs may have been enabled via config, various
-    # conditions (which are tested within the bowels of Inductor) may
-    # force cudagraphs to be disabled. This mutable box lets us retrieve
-    # the final determination if cudagraphs actually can be used or not.
-    cudagraphs = BoxedBool(config.triton.cudagraphs)
-
-    # See [Backward Generation Handling]
-    forward_device = BoxedDeviceIndex(None)
-
-    # TODO: The modern style is to use CompileId from TracingContext to
-    # identify Inductor compilation. However, this CompileId cannot
-    # uniquely identify multiple Inductor compilations that arise from
-    # DDPOptimizer
-    graph_id = next(_graph_counter)
+    compiler_config_extra = create_compiler_config_extra(config)

     decompositions = (
         decompositions if decompositions is not None else select_decomp_table()
@@ -2348,105 +2536,18 @@ def compile_fx(
         is_inference: bool,
     ) -> OutputCode:
         with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
-            if is_inference:
-                # partition_fn won't be called
-                trace_structured(
-                    "artifact",
-                    metadata_fn=lambda: {
-                        "name": "before_joint_graph",
-                        "encoding": "string",
-                    },
-                    payload_fn=lambda: gm.print_readable(
-                        print_output=False, include_stride=True, include_device=True
-                    ),
-                )
-
-                _recursive_joint_graph_passes(gm)
-
-                trace_structured(
-                    "artifact",
-                    metadata_fn=lambda: {
-                        "name": "after_joint_graph",
-                        "encoding": "string",
-                    },
-                    payload_fn=lambda: gm.print_readable(
-                        print_output=False, include_stride=True, include_device=True
-                    ),
-                )
-
-            fixed = torch._inductor.utils.num_fw_fixed_arguments(
-                num_example_inputs, len(example_inputs)
-            )
-
-            model_outputs_node = output_node(gm)
-            if config.keep_output_stride:
-                model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
-                num_model_outputs = len(model_outputs)
-
-                context = torch._guards.TracingContext.try_get()
-                # See Note [User Outputs in the inductor graph]
-                if context is not None and context.fw_metadata and not is_inference:
-                    original_output_start_index = (
-                        context.fw_metadata.num_mutated_inp_runtime_indices
-                    )
-                else:
-                    original_output_start_index = 0
-
-                if isinstance(model_, GraphModule):
-                    *_, orig_model_outputs_node = model_.graph.nodes
-                    assert orig_model_outputs_node.op == "output"
-                    orig_model_outputs, _ = pytree.tree_flatten(
-                        orig_model_outputs_node.args
-                    )
-                    num_orig_model_outputs = len(orig_model_outputs)
-                else:
-                    num_orig_model_outputs = num_model_outputs
-
-                assert num_orig_model_outputs <= num_model_outputs
-
-                # Note [User Outputs in the inductor graph]
-                # We makes the following assumption
-                # For inference
-                #   len(orig_model_outputs) == len(model_outputs)
-                # For training
-                #   len(orig_model_outputs) <= len(model_outputs)
-                # During training, most of the time the model_outputs starts with
-                # original module's outputs followed by saved activations.
-                # But this can be not true if the model have inplace updated tensors.
-                # AOTAutograd will make those tensors being returned before the original
-                # module's output.
-                # To make things safe, we'll use original_output_start_index field
-                # set by AOTAutograd to decide where the original module outputs start.
-                orig_output_end_idx = (
-                    original_output_start_index + num_orig_model_outputs
-                )
-                # Sanity check: we are about to splice out the "user" outputs from the full set
-                # of "graph" outputs. Make sure we're within bounds.
-                assert orig_output_end_idx <= num_model_outputs
-
-                model_outputs_node.meta["user_visible_output_idxs"] = [
-                    idx
-                    for idx in range(
-                        original_output_start_index, orig_output_end_idx
-                    )
-                    if isinstance(model_outputs[idx], torch.fx.Node)
-                ]
+            if isinstance(model_, GraphModule):
+                num_orig_model_outputs = get_num_model_outputs(model_)
             else:
-                model_outputs_node.meta["user_visible_output_idxs"] = []
-
-            # We also mark the invoke_subgraph outputs as user_visible to
-            # force the outputs of invoke_subgraph subgraph to follow the
-            # original strides
-            _recursive_record_user_visible_output_idxs(gm)
-
-            return inner_compile(
+                num_orig_model_outputs = get_num_model_outputs(gm)
+            return compile_fx_forward(
                 gm,
                 example_inputs,
-                static_input_idxs=get_static_input_idxs(fixed),
-                cudagraphs=cudagraphs,
-                graph_id=graph_id,
+                num_orig_model_outputs=num_orig_model_outputs,
+                num_example_inputs=num_example_inputs,
+                compiler_config_extra=compiler_config_extra,
+                inner_compile=inner_compile,
                 is_inference=is_inference,
-                boxed_forward_device_index=forward_device,
             )

     fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = (
@@ -2460,9 +2561,9 @@ def compile_fx(
             dynamo_model=model_,
             num_example_inputs=num_example_inputs,
             inner_compile=inner_compile,
-            cudagraphs=cudagraphs,
-            graph_id=graph_id,
-            forward_device=forward_device,
+            cudagraphs=compiler_config_extra.cudagraphs,
+            graph_id=compiler_config_extra.graph_id,
+            forward_device=compiler_config_extra.forward_device,
         )
     else:
         inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
@@ -2474,38 +2575,15 @@ def compile_fx(
     def bw_compiler(
         gm: GraphModule, example_inputs: Sequence[InputType]
     ) -> OutputCode:
-        from torch._dynamo.convert_frame import compile_lock
-
         with (
             dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"),
-            compile_lock,
         ):
-            model_outputs_node = output_node(gm)
-            if config.bw_outputs_user_visible:
-                model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
-                model_outputs_node.meta["user_visible_output_idxs"] = [
-                    idx
-                    for idx, n in enumerate(model_outputs)
-                    if isinstance(n, torch.fx.Node)
-                ]
-            else:
-                model_outputs_node.meta["user_visible_output_idxs"] = []
-
-            fixed = count_tangents(gm)
-            with (
-                config.patch(get_cpp_wrapper_config())
-                if config.cpp_wrapper
-                else contextlib.nullcontext()
-            ):
-                return inner_compile(
-                    gm,
-                    example_inputs,
-                    static_input_idxs=list(range(fixed)),
-                    cudagraphs=cudagraphs,
-                    is_backward=True,
-                    graph_id=graph_id,
-                    boxed_forward_device_index=forward_device,
-                )
+            return compile_fx_backward(
+                gm,
+                example_inputs,
+                compiler_config_extra=compiler_config_extra,
+                inner_compile=inner_compile,
+            )

     bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler)
@@ -2592,8 +2670,8 @@ def compile_fx(
                 decompositions=decompositions,
                 partition_fn=partition_fn,
                 keep_inference_input_mutations=True,
-                cudagraphs=cudagraphs,
-                boxed_forward_device_index=forward_device,
+                cudagraphs=compiler_config_extra.cudagraphs,
+                boxed_forward_device_index=compiler_config_extra.forward_device,
                 ignore_shape_env=ignore_shape_env,
             )(model_, example_inputs_)
         except ShortenTraceback as e: