pytorch/torch/_dynamo/functional_export.py
Tugsbayasgalan Manlaibaatar 6d65737aee testing infra and some fixes (#162183)
This PR is quite large in that it covers most of the rough edges in the new strict export flow:

1. Handle nn_module_stack correctly now that we are tracing the wrapper module.
2. module_call_spec needs to be queried from the source directly because we are not running the bytecode anymore.
3. Correct input and output handling.

@diff-train-skip-merge

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162183
Approved by: https://github.com/zhxchen17
2025-09-10 20:48:12 +00:00


import builtins
import inspect
import logging
import traceback
from collections import namedtuple
from typing import Any, Callable, Optional, Union

import sympy

import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.convert_frame import FrameInfo, fullgraph_capture, get_compile_id
from torch._dynamo.eval_frame import argument_names
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext
from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
from torch.fx import Node
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    DimDynamic,
    StatelessSymbolicContext,
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo


log = logging.getLogger(__name__)


def clean_nn_module_stack(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in graph_module.graph.nodes:
        if "nn_module_stack" in node.meta:
            nn_module_stack = node.meta["nn_module_stack"].copy()
            first_key = next(iter(nn_module_stack.keys()))
            if "export_root" in first_key:
                del nn_module_stack[first_key]
            nn_module_stack_corrected = {}
            for k, v in nn_module_stack.items():
                k_new = "".join(k.split("__export_root"))
                child_name, child_class = v
                child_name = child_name.replace("._export_root", "")
                nn_module_stack_corrected[k_new] = (child_name, child_class)
            node.meta["nn_module_stack"] = nn_module_stack_corrected
    return graph_module
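

# Illustrative example of the rewrite above (the concrete entries are made up;
# the key/value naming convention follows the one noted in clean_export_root
# below):
#
#     {"L__self____export_root_linear": ("L['self']._export_root.linear", Linear)}
#
# becomes
#
#     {"L__self___linear": ("L['self'].linear", Linear)}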


def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
    """Remove export_root artifacts from FX graph in-place."""

    # Clean parameter names: L__self____export_root_param -> L__self___param
    def clean_name(name) -> str:
        return name.replace("__export_root_", "_") if "__export_root_" in name else name

    # Update get_attr nodes in-place
    for node in graph_module.graph.nodes:
        if node.op == "get_attr":
            old_target = node.target
            new_target = clean_name(old_target)
            if new_target != old_target:
                node.target = new_target
                # Move the parameter to the new name
                if hasattr(graph_module, old_target):
                    param = torch.fx.graph_module._get_attr(graph_module, old_target)
                    torch.fx.graph_module._set_attr(graph_module, new_target, param)
                    torch.fx.graph_module._del_attr(graph_module, old_target)


class ModuleToTrace(torch.nn.Module):
    def __init__(self, foo: Any, in_spec: Any) -> None:
        super().__init__()
        self._export_root = foo
        self.in_spec = in_spec

    def forward(self, *flat_args: Any) -> "ExportTracerOutput":
        args, kwargs = pytree.tree_unflatten(flat_args, self.in_spec)
        res = self._export_root(*args, **kwargs)
        out_flat, out_spec = pytree.tree_flatten(res)
        return ExportTracerOutput(out_flat, out_spec)


ExportTracerOutput = namedtuple("ExportTracerOutput", ["flat_args", "out_spec"])
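

# Minimal sketch of the wrapping contract (illustrative only; ``MyModel``, ``x``,
# and ``s`` are hypothetical): the caller flattens ``(args, kwargs)`` once, and
# the wrapper re-inflates them, so dynamo only ever sees flat positional inputs.
#
#     flat, in_spec = pytree.tree_flatten(((x,), {"scale": s}))
#     wrapper = ModuleToTrace(MyModel(), in_spec)
#     out = wrapper(*flat)  # -> ExportTracerOutput(flat_args=..., out_spec=...)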


# mypy: disable-error-code="no-untyped-def,var-annotated,assignment,index,operator"
class DynamoGraphTransformer(torch.fx.Transformer):
    """Graph transformer for dynamo export that flattens inputs/outputs without complex matching."""

    def __init__(
        self,
        module: torch.fx.GraphModule,
        flat_inputs: list[Any],
        flat_args_dynamic_dims: list[set[int]],
        graph_input_order: dict[int, int],
        graph_output_map: dict[int, tuple[str, Any]],
        fake_mode: Optional[Any] = None,
    ) -> None:
        super().__init__(module)

        assert len(flat_args_dynamic_dims) == len(flat_inputs)

        self.flat_inputs = flat_inputs
        self.flat_args_dynamic_dims = flat_args_dynamic_dims
        self.graph_input_order = graph_input_order
        self.graph_output_map = graph_output_map
        self.fake_mode = fake_mode

        # Get original placeholders and output
        self.placeholders = [n for n in module.graph.nodes if n.op == "placeholder"]
        self.output_node = next(n for n in module.graph.nodes if n.op == "output")

        # Create new flattened input placeholders
        self.new_input_nodes: dict[int, torch.fx.Node] = {}
        self._create_flattened_inputs()

        # Iterator for replacing old placeholders
        self.old_to_new_mapping = {}
        self._create_placeholder_mapping()

    def _create_flattened_inputs(self) -> None:
        """Create new placeholder nodes for flattened inputs with proper fake tensors."""
        for i in range(len(self.flat_inputs)):
            placeholder = super().placeholder(f"arg_{i}", (), {})

            # Check if this user input (index i) maps to a graph placeholder
            if i in self.graph_input_order:
                # graph_input_order[i] gives us which graph placeholder this user input corresponds to
                graph_placeholder_idx = self.graph_input_order[i]
                if graph_placeholder_idx < len(self.placeholders):
                    orig_placeholder = self.placeholders[graph_placeholder_idx]
                    # Copy other metadata but not "val" yet
                    for key, value in orig_placeholder.meta.items():
                        if key != "val":
                            placeholder.node.meta[key] = value

            # Always ensure we have proper "val" metadata from fake tensor
            if self.fake_mode is not None and isinstance(
                self.flat_inputs[i], torch.Tensor
            ):
                placeholder.node.meta["val"] = self.fake_mode.from_tensor(
                    self.flat_inputs[i],
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[
                            (
                                DimDynamic.DYNAMIC
                                if d in self.flat_args_dynamic_dims[i]
                                else DimDynamic.STATIC
                            )
                            for d in range(len(self.flat_inputs[i].shape))
                        ],
                        constraint_sizes=[None] * len(self.flat_inputs[i].shape),
                    ),
                )
            elif hasattr(self.flat_inputs[i], "val"):  # _IntWrapper case
                placeholder.node.meta["val"] = self.flat_inputs[i].val
            else:
                placeholder.node.meta["val"] = self.flat_inputs[i]

            self.new_input_nodes[i] = placeholder
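
    # For illustration (assumed values): if flat_args_dynamic_dims[i] == {0} for a
    # 2-D tensor input, dynamic_sizes above becomes
    # [DimDynamic.DYNAMIC, DimDynamic.STATIC], i.e. only dim 0 of that placeholder
    # is traced as dynamic.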

    def _create_placeholder_mapping(self) -> None:
        """Create mapping from old placeholders to new ones."""
        # graph_input_order maps: user_input_index -> graph_placeholder_index
        # We need to create: old_graph_placeholder -> new_user_input_placeholder
        for user_input_idx, graph_placeholder_idx in self.graph_input_order.items():
            if graph_placeholder_idx < len(self.placeholders):
                old_placeholder = self.placeholders[graph_placeholder_idx]
                new_placeholder = self.new_input_nodes[user_input_idx]
                self.old_to_new_mapping[old_placeholder] = new_placeholder
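
    # Illustrative mapping (indices assumed): with graph_input_order == {2: 0, 0: 1},
    # dynamo's placeholder 0 is replaced by the new "arg_2" node and placeholder 1
    # by "arg_0".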

    def placeholder(self, target, args, kwargs) -> Any:
        """Replace old placeholders with new flattened ones."""
        # Return the corresponding new placeholder
        if self.current_node in self.old_to_new_mapping:
            new_arg = self.old_to_new_mapping[self.current_node]

            # Copy over additional metadata from current node, but don't overwrite "val"
            for key in ["tensor_dict", "example_value", "unbacked_bindings"]:
                if key in self.current_node.meta:
                    new_arg.node.meta[key] = self.current_node.meta[key]

            # Only copy "val" if we don't already have a good one
            if "val" in self.current_node.meta and "val" not in new_arg.node.meta:
                new_arg.node.meta["val"] = self.current_node.meta["val"]

            return new_arg
        else:
            # Shouldn't happen if mapping is correct, but fall back to default handling
            return super().placeholder(target, args, kwargs)

    def output(self, target, args, kwargs) -> Any:
        """Transform output according to graph_output_map."""
        original_outputs = args[0]

        # Build new output list based on graph_output_map
        new_outputs = []
        for i in sorted(self.graph_output_map.keys()):
            output_type, val = self.graph_output_map[i]
            if output_type == "graph_out":
                new_outputs.append(original_outputs[val])
            elif output_type == "input":
                input_idx = val.index
                new_outputs.append(self.new_input_nodes[input_idx])
            elif output_type == "constant":
                new_outputs.append(val)

        return super().output(target, (tuple(new_outputs),), {})
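
    # Illustrative shape of graph_output_map (entries assumed):
    #     {0: ("graph_out", 0), 1: ("input", source), 2: ("constant", None)}
    # would return the first graph output, mirror an input back through
    # source.index, and emit a literal constant, in that order.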

    def run_node(self, node: Node) -> Any:
        """Run node transformation and preserve metadata."""
        self.current_node = node
        result = super().run_node(node)

        # Copy important metadata
        if hasattr(result, "node") and result.node is not node:
            for key in ["val", "example_value", "unbacked_bindings"]:
                if key in node.meta:
                    result.node.meta[key] = node.meta[key]

            # Preserve node names (except output)
            if node.op != "output" and hasattr(node, "name"):
                result.node._rename(node.name)

        return result

    def transform(self) -> torch.fx.GraphModule:
        """Perform the graph transformation and copy module metadata."""
        result_gm = super().transform()

        # Copy module metadata like the original implementation
        if hasattr(self.module, "meta"):
            if "dynamo_flat_name_to_original_fqn" in self.module.meta:
                result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
                    "dynamo_flat_name_to_original_fqn"
                ]
            if "dynamo_compile_id" in self.module.meta:
                result_gm.meta["dynamo_compile_id"] = self.module.meta[
                    "dynamo_compile_id"
                ]

        return result_gm


def _dynamo_graph_capture_for_export(
    mod: Callable[..., Any],
    *,
    constraints: Optional[list[Constraint]] = None,
    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
) -> Callable[..., torch.fx.GraphModule]:
    """
    Improved dynamo graph capture using a transformer approach with proper fake
    tensor handling.

    This function creates a capture instance that handles:

    1. PyTree flattening/unflattening with proper input ordering
    2. Dynamo graph capture with an export-specific context
    3. FX graph transformation for export compatibility
    4. Proper fake tensor metadata preservation
    5. Dynamic dimension constraint handling

    Notable improvements over the manual approach:

    - Uses an FX Transformer for cleaner graph manipulation
    - Properly handles fake tensor metadata and dynamic dimensions
    - Preserves all necessary metadata for export
    - More robust error handling and edge-case management

    TODO:
    1. Are we actually gonna run the bytecode?
    2. Need to attach guards
    """
    _dynamic_shapes = dynamic_shapes
    _constraints = constraints

    def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
        flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
        module_to_trace = ModuleToTrace(mod, in_spec)

        signature = inspect.signature(module_to_trace.forward)
        bound_arguments = signature.bind(*flat_inputs)
        bound_arguments.apply_defaults()

        constraints: Optional[list[Constraint]] = _constraints
        dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = (
            _dynamic_shapes
        )

        from . import reset  # type: ignore[attr-defined]

        reset()

        f_locals = {"self": module_to_trace, **bound_arguments.arguments}

        frame = FrameInfo(
            module_to_trace.forward.__func__.__code__,  # type: ignore[attr-defined]
            module_to_trace.forward.__func__.__globals__,  # type: ignore[attr-defined]
            f_locals,
            builtins,  # type: ignore[arg-type]
            closure=(),  # type: ignore[arg-type]
        )

        dynamo_config_ctx = torch._dynamo.config.patch(
            specialize_int=True,
            specialize_float=True,
            assume_static_by_default=True,
            automatic_dynamic_shapes=False,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
            prefer_deferred_runtime_asserts_over_guards=False,
            log_graph_in_out_metadata=True,
        )

        with (
            compile_context(CompileContext(get_compile_id({}))),
            get_metrics_context(),
            dynamo_timed("fullgraph_capture"),
            dynamo_config_ctx,
        ):
            out = fullgraph_capture(
                frame,
                constraints=_constraints,
                _is_export_deprecated_do_not_use=True,
            )

        assert out.dynamo_output.tracer_output.output_graph is not None

        # Extract export metadata from the new location
        export_metadata = (
            out.dynamo_output.tracer_output.output_graph.export_metadata
        )
        graph_inputs = export_metadata.graph_input_idx_to_local_source
        graph_output_map = export_metadata.output_return_type
        out_spec = export_metadata.out_spec
        module_call_spec = export_metadata.module_call_spec

        example_inputs: list[Any] = []
        if out.backend_input is not None:
            graph = out.backend_input.graph_module
            fake_mode = out.backend_input.fake_mode
            example_inputs = out.backend_input.example_inputs
        else:
            graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
            graph.graph.output(None)
            graph.recompile()
            fake_mode = out.dynamo_output.tracer_output.output_graph.fake_mode

        # Compute dynamic dimensions for each input based on constraints
        flat_args_dynamic_dims = [
            {
                c.dim
                for c in (constraints or ())
                if (
                    c.t_id == id(x)
                    and not isinstance(c, _RelaxedConstraint)
                    and c.constraint_range.vr.lower != c.constraint_range.vr.upper
                )
            }
            for x in flat_inputs
        ]

        # Create input order mapping from dynamo's internal order to user order
        graph_input_order: dict[int, int] = {}
        for inp in graph_inputs:
            source = graph_inputs[inp]
            assert isinstance(source, torch._dynamo.source.GetItemSource)
            graph_input_order[source.index] = len(graph_input_order)

        for real_idx, graph_idx in graph_input_order.items():
            flat_inputs[real_idx] = example_inputs[graph_idx]

        # Use FX transformer to rebuild the graph cleanly
        transformed_graph = DynamoGraphTransformer(
            graph,
            flat_inputs,
            flat_args_dynamic_dims,
            graph_input_order,
            graph_output_map,
            fake_mode,
        ).transform()

        # Set up PyTree codegen for proper input/output handling
        transformed_graph.graph._codegen = _PyTreeCodeGen(
            _PyTreeInfo(
                argument_names(inspect.signature(mod.forward), args, kwargs),  # type: ignore[attr-defined, arg-type]
                in_spec,
                out_spec,
            )
        )
        transformed_graph.recompile()

        clean_nn_module_stack(transformed_graph)
        clean_export_root(transformed_graph)

        transformed_graph.meta["module_call_specs"] = module_call_spec

        constraint_violation_error = None
        try:
            # Check if we have any constraint violations
            check_fn = out.dynamo_output.build_guards(
                module_to_trace.forward.__code__
            ).guard_manager
            check_fn.check(f_locals)
        except (
            ConstraintViolationError,
            torch.utils._sympy.value_ranges.ValueRangeError,
        ) as e:
            constraint_violation_error = e

        if (
            (shape_env := getattr(fake_mode, "shape_env", None)) is not None
            and (dim_constraints := shape_env.dim_constraints) is not None
            and not isinstance(
                module_to_trace.forward,
                (torch._ops.OpOverloadPacket, torch._ops.OpOverload),
            )
        ):
            dim_constraints.solve()
            forced_specializations = dim_constraints.forced_specializations()
            msg = dim_constraints.prettify_results(
                inspect.signature(mod.forward),  # type: ignore[attr-defined]
                dynamic_shapes,
                constraint_violation_error,
                forced_specializations,
            )
            if constraint_violation_error:
                constraint_violation_error.args = (
                    constraint_violation_error.args[0] + msg,
                )
            else:
                if forced_specializations:
                    constraint_violation_error = ConstraintViolationError(msg)
                else:
                    log.info(
                        "Summary of dimension constraints:%s",
                        msg,
                    )

            # Error if we have any constraints on static values
            for k in shape_env.var_to_range.keys():
                if isinstance(k, sympy.Integer):
                    constraint_violation_error = ConstraintViolationError(
                        f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
                        "It appears that you're trying to set a constraint on a "
                        f"value which we evaluated to have a static value of {k}. "
                        'Set TORCH_LOGS="+export" for more information.'
                    )

        if constraint_violation_error:
            raise constraint_violation_error

        return transformed_graph

    return inner
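

# Minimal usage sketch (illustrative only; ``Toy`` and its input are made up,
# and this is not a stable public API):
#
#     class Toy(torch.nn.Module):
#         def forward(self, x):
#             return x + 1
#
#     capture = _dynamo_graph_capture_for_export(Toy())
#     gm = capture(torch.randn(3))  # returns a torch.fx.GraphModule
#     gm.print_readable()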