New export implementation with flat inp/out (#162167)

This is my first attempt at building the new export API. The main thing it addresses is correctly getting input and output relations. Subsequent diffs will add functionality for dynamic shapes, nn_module_stack, etc.

Differential Revision: [D81793205](https://our.internmc.facebook.com/intern/diff/D81793205)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162167
Approved by: https://github.com/zhxchen17, https://github.com/avikchaudhuri
Tugsbayasgalan Manlaibaatar 2025-09-04 11:03:53 -07:00 committed by PyTorch MergeBot
parent ae0edc133e
commit 047603d35b
8 changed files with 387 additions and 87 deletions
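For context, a minimal sketch of the new entry point this diff adds (torch._dynamo.functional_export._dynamo_graph_capture_for_export; see the new file below):

import torch
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()

# Capture returns an fx.GraphModule whose pytree codegen restores the
# original calling convention over a flat list of graph inputs/outputs.
gm = _dynamo_graph_capture_for_export(M())(torch.randn(4, 4))
out = gm(torch.randn(4, 4))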

View File

@@ -1,5 +1,6 @@
# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import types
import unittest
from typing import Dict, List, Tuple
@@ -351,6 +352,62 @@ def forward(self, x):
res2 = p.generate(input_tensor=inp, input_tensor2=inp2)
self.assertTrue(torch.allclose(res, res2))
def test_export_add_in_out_info(self):
class Foo(torch.nn.Module):
def forward(self, dct, lst, bleh):
x = dct["a"] * lst[1][0]
y = dct["b"] * lst[0]
out_dict = {}
# Mutate and get a new entry in there
lst_copy = lst.copy()
lst_copy.append(lst[0])
out_dict["a"] = x
out_dict["b"] = y
return (
dct["a"],
out_dict["b"],
bleh,
lst_copy[-1],
out_dict["a"],
[5, 6],
)
dct = {"a": torch.randn(2, 3), "b": torch.randn(2, 3)}
lst = [torch.randn(2, 3), [torch.randn(2, 3), torch.randn(2, 3)]]
export_inputs = ((dct, lst, 56), {})
eager_inputs = copy.deepcopy(export_inputs)
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
graph_module = _dynamo_graph_capture_for_export(Foo())(
*export_inputs[0], **export_inputs[1]
)
res_export = graph_module(*export_inputs[0], **export_inputs[1])
res_eager = Foo()(*eager_inputs[0], **eager_inputs[1])
self.assertEqual(res_export, res_eager)
def test_export_leaf(self):
class Foo(torch.nn.Module):
def forward(self, x):
return x.sin()
export_inputs = ((torch.randn(4, 4),), {})
eager_inputs = copy.deepcopy(export_inputs)
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
graph_module = _dynamo_graph_capture_for_export(Foo())(
*export_inputs[0], **export_inputs[1]
)
res_export = graph_module(*export_inputs[0], **export_inputs[1])
res_eager = Foo()(*eager_inputs[0], **eager_inputs[1])
self.assertEqual(res_export, res_eager)
if __name__ == "__main__":
run_tests()
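The tests above lean on pytree flattening; a short sketch of the flat form the tracer sees for test_export_add_in_out_info's inputs (using torch.utils._pytree):

import torch
import torch.utils._pytree as pytree

dct = {"a": torch.randn(2, 3), "b": torch.randn(2, 3)}
lst = [torch.randn(2, 3), [torch.randn(2, 3), torch.randn(2, 3)]]
flat, in_spec = pytree.tree_flatten(((dct, lst, 56), {}))
# flat has 6 leaves: dct["a"], dct["b"], lst[0], lst[1][0], lst[1][1], 56.
# in_spec records the container structure, so tree_unflatten(flat, in_spec)
# reconstructs the original (args, kwargs) pair.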

View File

@@ -10,7 +10,14 @@ seamlessly optimize PyTorch programs, including those using modern Python featur
import torch
from . import aot_compile, config, convert_frame, eval_frame, resume_execution
from . import (
aot_compile,
config,
convert_frame,
eval_frame,
functional_export,
resume_execution,
)
from .backends.registry import list_backends, lookup_backend, register_backend
from .callback import callback_handler, on_compile_end, on_compile_start
from .code_context import code_context

View File

@@ -110,6 +110,12 @@ automatic_dynamic_shapes = True
# Valid options: "dynamic", "unbacked"
automatic_dynamic_shapes_mark_as: Literal["dynamic", "unbacked"] = "dynamic"
# log graph in/out metadata
# This is only turned on for export today, since we
# know we are tracing a flat callable. Later, this
# can be extended to other use cases as well.
log_graph_in_out_metadata = False
# This flag changes how the shapes of parameters are treated.
# If this flag is set to True, then the shapes of torch.nn.Parameter as well as of torch.Tensor are attempted to be dynamic
# If this flag is set to False, then the shapes of torch.nn.Parameter are assumed to be static,
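The new flag is enabled in a scoped way rather than globally; a sketch of the pattern the new capture path uses (see functional_export.py below):

import torch

with torch._dynamo.config.patch("log_graph_in_out_metadata", True):
    ...  # tracing inside this context records in/out metadata on the OutputGraph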

View File

@@ -906,7 +906,9 @@ class FrameInfo:
closure: tuple[CellType]
def fullgraph_capture(frame: FrameInfo) -> CaptureOutput:
def fullgraph_capture(
frame: FrameInfo, *, _is_export_deprecated_do_not_use: bool = False
) -> CaptureOutput:
"""
A standalone function which takes a frame and returns dynamo captured graph
plus other important compile information. This should serve as the common
@@ -948,6 +950,7 @@ def fullgraph_capture(frame: FrameInfo) -> CaptureOutput:
frame.builtins,
frame.closure,
compiler_fn=fullgraph_compiler,
export=_is_export_deprecated_do_not_use,
one_graph=True,
restart_reasons=set(),
)

View File

@@ -1124,6 +1124,89 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
return fn
# Make the dynamo graph have the same input/output spec as the user code
def argument_names(
f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any]
) -> list[str]:
def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
# Get a list of Parameter objects from the Signature object
params = list(sig.parameters.values())
# Separate positional arguments, keyword-only arguments and varargs/varkw
args = [
p.name for p in params if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
]
kwonlyargs = [
p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
]
varargs = next(
(p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
None,
)
varkw = next(
(p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
None,
)
# Get default values for positional arguments and keyword-only arguments
defaults = tuple(
p.default
for p in params
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
and p.default is not inspect.Parameter.empty
)
kwonlydefaults = {
p.name: p.default
for p in params
if p.kind == inspect.Parameter.KEYWORD_ONLY
and p.default is not inspect.Parameter.empty
}
# Get annotations for parameters and return value
annotations = {}
if sig.return_annotation:
annotations = {"return": sig.return_annotation}
for parameter in params:
annotations[parameter.name] = parameter.annotation
# Return a FullArgSpec object with the extracted attributes
return inspect.FullArgSpec(
args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
)
fullargspec = signature_to_fullargspec(f_sig)
# 1. Map `args` 1-to-1 to positional arguments in original signature.
input_strs = fullargspec.args[: len(args)]
if len(args) > len(fullargspec.args):
# 2. If there are more arguments left in `args`, they map to varargs in original
# signature. Assign names as {varargs}_0, {varargs}_1, ...
assert fullargspec.varargs is not None, "More arguments than expected"
input_strs += [
f"{fullargspec.varargs}_{i}" for i in range(0, len(args) - len(input_strs))
]
elif len(args) < len(fullargspec.args):
# 3. If there are fewer arguments in `args` than `fullargspec.args`,
# it implies these are arguments either with default values, or provided in
# `kwargs`. The former can be safely ignored, because Dynamo export does not
# export them as part of the function signature. The latter will be handled
# in the next step.
for unprovided_arg in fullargspec.args[
len(args) : -len(fullargspec.defaults or [])
]:
assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
# 4. Keyword arguments provided in `kwargs`.
input_strs += list(kwargs.keys())
# 5. Keyword-only arguments with default values if not provided are not exported
# as part of the function signature.
for kwonly_arg in fullargspec.kwonlyargs:
kwonlydefaults = fullargspec.kwonlydefaults or {}
assert kwonly_arg in kwargs or kwonly_arg in kwonlydefaults, (
f"Missing keyword only argument {kwonly_arg}"
)
return input_strs
def check_if_dynamo_supported() -> None:
if sys.version_info >= (3, 14):
raise RuntimeError("Python 3.14+ not yet supported for torch.compile")
@@ -1650,91 +1733,6 @@ def rewrite_signature(
fake_mode,
).transform()
# Make dynamo graph to have same input/output spec as user code
def argument_names(
f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any]
) -> list[str]:
def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
# Get a list of Parameter objects from the Signature object
params = list(sig.parameters.values())
# Separate positional arguments, keyword-only arguments and varargs/varkw
args = [
p.name
for p in params
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
]
kwonlyargs = [
p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
]
varargs = next(
(p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
None,
)
varkw = next(
(p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
None,
)
# Get default values for positional arguments and keyword-only arguments
defaults = tuple(
p.default
for p in params
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
and p.default is not inspect.Parameter.empty
)
kwonlydefaults = {
p.name: p.default
for p in params
if p.kind == inspect.Parameter.KEYWORD_ONLY
and p.default is not inspect.Parameter.empty
}
# Get annotations for parameters and return value
annotations = {}
if sig.return_annotation:
annotations = {"return": sig.return_annotation}
for parameter in params:
annotations[parameter.name] = parameter.annotation
# Return a FullArgSpec object with the extracted attributes
return inspect.FullArgSpec(
args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
)
fullargspec = signature_to_fullargspec(f_sig)
# 1. Map `args` 1-to-1 to positional arguments in original signature.
input_strs = fullargspec.args[: len(args)]
if len(args) > len(fullargspec.args):
# 2. If there are more arguments left in `args`, they map to varargs in original
# signature. Assign names as {varargs}_0, {varargs}_1, ...
assert fullargspec.varargs is not None, "More arguments than expected"
input_strs += [
f"{fullargspec.varargs}_{i}"
for i in range(0, len(args) - len(input_strs))
]
elif len(args) < len(fullargspec.args):
# 3. If there are fewer arguments in `args` than `fullargspec.args`,
# it implies these are arguments either with default values, or provided in
# `kwargs`. The former can be safely ignored. Because Dynamo.export does not
# export them as part of the function signature. The latter will be handled
# in the next step.
for unprovided_arg in fullargspec.args[
len(args) : -len(fullargspec.defaults or [])
]:
assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
# 4. Keyword arguments provided in `kwargs`.
input_strs += list(kwargs.keys())
# 5. Keyword-only arguments with default values if not provided are not exported
# as part of the function signature.
for kwonly_arg in fullargspec.kwonlyargs:
kwonlydefaults = fullargspec.kwonlydefaults or {}
assert kwonly_arg in kwargs or kwonly_arg in kwonlydefaults, (
f"Missing keyword only argument {kwonly_arg}"
)
return input_strs
new_graph.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(
argument_names(f_sig, orig_args, orig_kwargs),
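To make the numbered steps in argument_names concrete, a small example (hypothetical signature, mine):

import inspect
from torch._dynamo.eval_frame import argument_names

def f(a, b=1, *args, c, d=2, **kwargs):
    pass

# Step 1 binds a and b; step 2 spills the extra positional into "args_0";
# step 4 appends the provided keyword "c"; step 5 checks that "d" has a
# default, so it is dropped from the exported signature.
names = argument_names(inspect.signature(f), args=[10, 20, 30], kwargs={"c": 40})
assert names == ["a", "b", "args_0", "c"]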

View File

@@ -0,0 +1,142 @@
import builtins
import inspect
from collections import namedtuple
from typing import Any, Callable
import torch
import torch.utils._pytree as pytree
from torch._dynamo.convert_frame import FrameInfo, fullgraph_capture, get_compile_id
from torch._dynamo.eval_frame import argument_names
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
class ModuleToTrace(torch.nn.Module):
def __init__(self, foo: Any, in_spec: Any) -> None:
super().__init__()
self._export_root = foo
self.in_spec = in_spec
def forward(self, *flat_args: Any) -> "ExportTracerOutput":
args, kwargs = pytree.tree_unflatten(flat_args, self.in_spec)
res = self._export_root(*args, **kwargs)
out_flat, out_spec = pytree.tree_flatten(res)
return ExportTracerOutput(out_flat, out_spec)
ExportTracerOutput = namedtuple("ExportTracerOutput", ["flat_args", "out_spec"])
def _dynamo_graph_capture_for_export(
mod: torch.nn.Module,
) -> Callable[..., torch.fx.GraphModule]:
"""
This is a lower-level API used by export to capture dynamo-level
torch IR.
Notable TODOs:
1. Are we actually going to run the bytecode?
2. Need to attach guards
"""
def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
module_to_trace = ModuleToTrace(mod, in_spec)
signature = inspect.signature(module_to_trace.forward)
bound_arguments = signature.bind(*flat_inputs)
bound_arguments.apply_defaults()
f_locals = {"self": module_to_trace, **bound_arguments.arguments}
frame = FrameInfo(
module_to_trace.forward.__func__.__code__, # type: ignore[attr-defined]
module_to_trace.forward.__func__.__globals__, # type: ignore[attr-defined]
f_locals,
builtins, # type: ignore[arg-type]
closure=(), # type: ignore[arg-type]
)
dynamo_config_ctx = torch._dynamo.config.patch(
"log_graph_in_out_metadata", True
)
with (
compile_context(CompileContext(get_compile_id({}))),
get_metrics_context(),
dynamo_timed("fullgraph_capture"),
dynamo_config_ctx,
):
out = fullgraph_capture(frame, _is_export_deprecated_do_not_use=True)
assert out.dynamo_output.tracer_output.output_graph is not None
export_metadata = (
out.dynamo_output.tracer_output.output_graph.export_metadata
)
graph_inputs = export_metadata.graph_input_idx_to_local_source
output_return_type = export_metadata.output_return_type
# We need to extract out_spec here because we are not actually running the bytecode
out_spec = export_metadata.out_spec
graph = out.backend_input.graph_module
# It is not guaranteed that dynamo puts inputs in the right order, so we need
# to map the actual user order to the dynamo order.
graph_input_order: dict[int, int] = {}
for inp in graph_inputs:
source = graph_inputs[inp]
assert isinstance(source, torch._dynamo.source.GetItemSource)
graph_input_order[source.index] = len(graph_input_order)
placeholders = [n for n in list(graph.graph.nodes) if n.op == "placeholder"]
output = next(n for n in list(graph.graph.nodes) if n.op == "output")
# Sometimes there are no inputs at all
anchor = placeholders[0] if len(placeholders) > 0 else output
inp_to_node = {}
with graph.graph.inserting_before(anchor):
for i in range(len(flat_inputs)):
node_new = graph.graph.placeholder(f"arg_{i}")
if i in graph_input_order:
node_new.meta = placeholders[graph_input_order[i]].meta.copy()
inp_to_node[i] = node_new
new_args = []
for i in output_return_type:
type, val = output_return_type[i]
if type == "graph_out":
new_args.append(output.args[0][val])
if type == "input":
input_idx = val.index
new_args.append(inp_to_node[input_idx])
if type == "constant":
new_args.append(val)
output.args = (tuple(new_args),)
# graph_input_order maps user (flat) input index -> dynamo placeholder index
for src_idx, i in graph_input_order.items():
old = placeholders[i]
new = inp_to_node[src_idx]
old.replace_all_uses_with(new)
graph.graph.erase_node(old)
# Dynamo uses _LazyGraphModule, so we need to force a recompile
from torch.fx._lazy_graph_module import _LazyGraphModule
_LazyGraphModule.force_recompile(graph)
graph.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(
argument_names(signature, args, kwargs), # type: ignore[arg-type]
in_spec,
out_spec,
)
)
graph.recompile()
return graph
return inner
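The placeholder rewiring above uses the standard FX idiom; a self-contained toy version of the same reordering (illustrative, not from this diff):

import torch
import torch.fx as fx

def f(x, y):
    return y - x

gm = fx.symbolic_trace(f)
g = gm.graph
olds = [n for n in g.nodes if n.op == "placeholder"]  # traced order: x, y
with g.inserting_before(olds[0]):
    new_y = g.placeholder("arg_0")  # desired first user input
    new_x = g.placeholder("arg_1")
olds[1].replace_all_uses_with(new_y)
olds[0].replace_all_uses_with(new_x)
for n in olds:
    g.erase_node(n)
gm.recompile()
# gm(a, b) now computes a - b: callers pass what used to be y first.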

View File

@@ -363,6 +363,24 @@ class StackLocalsMetadata:
locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list)
# TODO we should expand this to make it work for arbitrary in/out
@dataclass
class ExportMetaData:
# maps a graph input index to its source, which is later
# used in export to map to the correct user input. In its flat form,
# this just looks like GetItemSource(base=LocalSource("foo"), index=0)
graph_input_idx_to_local_source: dict[int, Source] = dc_field(default_factory=dict)
# maps a user output idx to what kind of output it is. There are 3 options:
# 1) graph out
# 2) user input
# 3) constants
output_return_type: dict[int, tuple[str, Any]] = dc_field(default_factory=dict)
# output spec of the traced function
out_spec: Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec] = (
torch.utils._pytree._LEAF_SPEC
)
def get_builtins_dict(global_scope: Scope) -> dict[str, Any]:
# f_globals["__builtins__"] can be a dict or a module. This is an
# implementation detail -
@@ -598,6 +616,8 @@ class OutputGraph(OutputGraphGuardsState):
# mangled alias -> module fqn name
self.import_sources: dict[str, str] = {}
self.export_metadata = ExportMetaData()
def mark_bytecode_tracing_start(self) -> None:
self.compiler_trace_stack.enter_context(
dynamo_timed(
@@ -1494,6 +1514,54 @@ class OutputGraph(OutputGraphGuardsState):
)
self.codegen_suffix(tx, stack_values_flat, pass2)
if (
torch._dynamo.config.log_graph_in_out_metadata
and stack_values_flat
and len(stack_values_flat) == 1
):
vt = stack_values_flat[0]
if (
isinstance(vt, torch._dynamo.variables.NamedTupleVariable)
and vt.tuple_cls
is torch._dynamo.functional_export.ExportTracerOutput
):
flat_returns = vt.items[0]
out_spec = vt.items[1]
assert isinstance(
flat_returns, torch._dynamo.variables.ListVariable
)
vt_to_graph_out_idx: dict[VariableTracker, int] = {}
for value in pass2.graph_outputs.values():
assert isinstance(value, torch._dynamo.codegen.GraphOutputEntry)
variable: VariableTracker = value.variable
vt_to_graph_out_idx[variable] = value.index
for idx, vt in enumerate(flat_returns.items):
if vt in vt_to_graph_out_idx:
self.export_metadata.output_return_type[idx] = (
"graph_out",
vt_to_graph_out_idx[vt],
)
elif (
vt.source is not None
and (source := getattr(vt.source, "base", None))
and source.is_input
):
self.export_metadata.output_return_type[idx] = (
"input",
vt.source,
)
elif isinstance(vt, torch._dynamo.variables.ConstantVariable):
self.export_metadata.output_return_type[idx] = (
"constant",
vt.as_python_constant(),
)
else:
assert f"Encountered unrecognized type {vt} at output {idx}" # noqa: PLW0129
self.export_metadata.out_spec = out_spec.as_python_constant()
output = []
if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
output.extend(
@@ -2039,6 +2107,10 @@ class OutputGraph(OutputGraphGuardsState):
assert self.root_tx is not None
cg = PyCodegen(self.root_tx)
for idx, arg in enumerate(self.graphargs):
self.export_metadata.graph_input_idx_to_local_source[idx] = arg.source
cg.make_call_generated_code(name)
return cg.get_instructions()
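Tying output_return_type back to test_export_add_in_out_info above: after pytree flattening in ModuleToTrace.forward, the module returns seven leaves, which this code would classify roughly as follows (illustrative; exact source/index values depend on the trace):

# idx 0: dct["a"]      -> ("input", ...)       a returned user input
# idx 1: out_dict["b"] -> ("graph_out", ...)   computed in the graph
# idx 2: bleh (56)     -> ("constant", 56)     specialized int input
# idx 3: lst_copy[-1]  -> ("input", ...)       aliases lst[0]
# idx 4: out_dict["a"] -> ("graph_out", ...)
# idx 5: 5             -> ("constant", 5)      leaves of the literal [5, 6]
# idx 6: 6             -> ("constant", 6)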

View File

@@ -1748,6 +1748,21 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
ctor = self.python_type()
return ctor(*args, **kwargs)
def reconstruct(self, codegen: "PyCodegen") -> None:
# Handle specific pytree classes
import torch.utils._pytree as pytree
if self.value_type is pytree.LeafSpec:
# Create a new LeafSpec instance by calling the constructor
codegen.add_push_null(
lambda: codegen.load_import_from("torch.utils._pytree", "LeafSpec")
)
codegen.extend_output(create_call_function(0, False))
return
# For other frozen dataclasses, fall back to the base class behavior
super().reconstruct(codegen)
# NB: This is called during __init__ for a frozen dataclass
# use this to accumulate the most up-to-date field values
def method_setattr_standard(self, tx: "InstructionTranslator", name, value):
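For reference, pytree's LeafSpec is an argument-free frozen dataclass, which is why the reconstruction above reduces to a plain constructor call; a sketch of the runtime effect:

import torch.utils._pytree as pytree

# The emitted bytecode is equivalent to:
spec = pytree.LeafSpec()
# A single leaf flattens to exactly this spec:
assert pytree.tree_flatten(0)[1] == spec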