[export] Unified graph capture with fullgraph_capture. (#165562)

Summary:
_dynamo_graph_capture_for_export in its current form has compatibility issues
with the main torch.compile() path, even though it reuses fullgraph_capture as
the bytecode tracer. The reason is that we flip on many export-specific flags
and even trace a wrapped function, which again causes divergence from
torch.compile().

This PR instead creates a new implementation of dynamo_graph_capture_for_export
that relies entirely on fullgraph_capture plus post-processing of CaptureOutput,
so that we avoid inverting the phases of the PT2 compiler stack.

This also benefits the precompile workflow, since we want a feature that
accepts only pytree inputs and ships portable Python wrappers in the package.
In other words, the code here should be sharable between export and precompile
for exporting a portable graph.
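
A minimal usage sketch of the new entry point (mirroring the tests added in
this PR; the module and shapes here are illustrative):

    import torch
    from torch._dynamo.functional_export import dynamo_graph_capture_for_export

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            # pytree-structured input and output
            return {"y": self.linear(x["a"] + 1)}

    mod = M()
    # Trace once with example inputs; the returned callable keeps the original
    # pytree calling convention.
    gm = dynamo_graph_capture_for_export(mod)({"a": torch.randn(3, 3)})
    out = gm({"a": torch.randn(3, 3)})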

Test Plan:
===================================================================== test session starts =====================================================================
platform linux -- Python 3.12.11, pytest-7.3.2, pluggy-1.6.0
rootdir: /data/users/zhxchen17/pytorch
configfile: pytest.ini
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, subtests-0.13.1, rerunfailures-14.0, flakefinder-1.1.0, cpp-2.3.0, anyio-4.10.0
collected 9 items
Running 9 items in this shard

test/distributed/tensor/test_dtensor_export.py ........x                                                                                                [100%]

================================================================ 8 passed, 1 xfailed in 11.42s ================================================================

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165562
Approved by: https://github.com/tugsbayasgalan
zhxchen17 2025-10-22 20:44:51 +00:00 committed by PyTorch MergeBot
parent 291712026b
commit 757975ad50
6 changed files with 297 additions and 11 deletions


@@ -6,7 +6,10 @@ import unittest
import torch
import torch.distributed as dist
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.functional_export import (
_dynamo_graph_capture_for_export,
dynamo_graph_capture_for_export,
)
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch._guards import tracing, TracingContext
@@ -96,6 +99,13 @@ def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
def graph_capture_and_aot_export_joint_with_descriptors_v2(model, inputs):
gm = dynamo_graph_capture_for_export(model)(inputs)
fake_mode = gm.meta.get("fake_mode", None)
with tracing(TracingContext(fake_mode)):
return aot_export_joint_with_descriptors_alone(gm, inputs)
def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
@@ -288,6 +298,7 @@ class DTensorExportTest(TestCase):
@parametrize(
"export_fn",
[
graph_capture_and_aot_export_joint_with_descriptors_v2,
graph_capture_and_aot_export_joint_with_descriptors,
aot_export_joint_with_descriptors_alone,
],
@@ -307,7 +318,21 @@ class DTensorExportTest(TestCase):
def test_annotate_aot_export_joint_with_descriptors_alone(self):
self._run_test(aot_export_joint_with_descriptors_alone, True)
def test_dynamic_shapes(self):
@parametrize(
"export_fn_with_answer",
[
(
graph_capture_and_aot_export_joint_with_descriptors_v2,
"[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
),
(
graph_capture_and_aot_export_joint_with_descriptors,
"[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
),
],
)
def test_dynamic_shapes(self, export_fn_with_answer):
export_fn, answer = export_fn_with_answer
dp_degree = 2
tp_degree = self.world_size // dp_degree
@@ -331,7 +356,7 @@ class DTensorExportTest(TestCase):
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
joint_gm = graph_capture_and_aot_export_joint_with_descriptors(tp_model, inputs)
joint_gm = export_fn(tp_model, inputs)
res = []
for node in joint_gm.graph.nodes:
@@ -341,12 +366,16 @@ class DTensorExportTest(TestCase):
if isinstance(fake_val, torch._subclasses.fake_tensor.FakeTensor):
res.append(list(fake_val.shape))
self.assertExpectedInline(
str(res),
"""[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
)
self.assertEqual(str(res), answer)
def test_einsum_dtensor_export(self):
@parametrize(
"export_fn",
[
dynamo_graph_capture_for_export,
_dynamo_graph_capture_for_export,
],
)
def test_einsum_dtensor_export(self, export_fn):
"""Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
world_size = 4
# Create device mesh
@@ -366,9 +395,7 @@ class DTensorExportTest(TestCase):
output = model(x_dtensor, y_dtensor, z_dtensor)
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
gm = _dynamo_graph_capture_for_export(model)(
x_dtensor, y_dtensor, z_dtensor
)
gm = export_fn(model)(x_dtensor, y_dtensor, z_dtensor)
output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
self.assertEqual(output, output_gm)


@@ -471,6 +471,67 @@ from user code:
assert hasattr(backend_result.compiled_fn, "serialize")
self.assertIsNotNone(backend_result.compiled_fn.serialize)
def test_fullgraph_capture_with_pytree_module(self):
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
self.linear1 = torch.nn.Linear(3, 3)
self.linear2 = torch.nn.Linear(3, 3)
self.linear3 = torch.nn.Linear(3, 3)
def forward(self, x):
return {
"y": self.linear2(x[2] + 1),
"z": self.linear3(x[1] - 1),
"w": self.linear(x[0]["b"] + 2),
"v": self.linear1(x[0]["a"] - 2),
}
mod = Module()
compiled_mod = dynamo_graph_capture_for_export(mod)(
(
{"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
torch.randn(3, 3),
torch.randn(3, 3),
)
)
inputs = (
{"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
torch.randn(3, 3),
torch.randn(3, 3),
)
self.assertEqual(compiled_mod(inputs), mod(inputs))
def test_fullgraph_capture_with_pytree_func(self):
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
def foo(x):
return {
"y": x[2] + 1,
"z": x[1] - 1,
"w": x[0]["b"] + 2,
"v": x[0]["a"] - 2,
}
compiled_foo = dynamo_graph_capture_for_export(foo)(
(
{"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
torch.randn(2, 3),
torch.randn(3, 4),
)
)
inputs = (
{"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
torch.randn(2, 3),
torch.randn(3, 4),
)
self.assertEqual(compiled_foo(inputs), foo(inputs))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests


@@ -402,6 +402,43 @@ def forward(self, x):
self.assertEqual(res_export, res_eager)
def test_dynamo_graph_capture(self):
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
class Foo(torch.nn.Module):
def forward(self, dct, lst, bleh):
x = dct["a"] * lst[1][0]
y = dct["b"] * lst[0]
out_dict = {}
# Mutate and get a new entry in there
lst_copy = lst.copy()
lst_copy.append(lst[0])
out_dict["a"] = x
out_dict["b"] = y
return (
dct["a"],
out_dict["b"],
bleh,
lst_copy[-1],
out_dict["a"],
[5, 6],
)
foo = Foo()
def make_inputs():
return (
{"a": torch.randn(2, 3), "b": torch.randn(2, 3)},
[torch.randn(2, 3), (torch.randn(2, 3),)],
torch.randn(2, 3),
)
trace_inputs = make_inputs()
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
test_inputs = make_inputs()
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
if __name__ == "__main__":
run_tests()


@@ -886,6 +886,7 @@ class DynamoOutput:
return GraphCaptureOutput(
OutputGraphCommon(
output_graph.dump_guards_state(),
output_graph.import_sources,
output_graph.shape_env,
output_graph.export_metadata,
output_graph.tracked_fakes_id_to_source,
@@ -960,6 +961,27 @@ class CaptureOutput:
# BackendInput can be None when dynamo didn't compile any graph (no tensor op)
backend_input: Optional[BackendInput]
def forward_callable(self) -> Callable[..., Any]:
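# Rebuild a Python function from the captured bytecode, binding the recorded
# import aliases and the compiled graph module (under its backend_id) into the
# function's globals.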
import importlib
# TODO code sharing
import_sources = self.graph_capture_output.output_graph.import_sources
assert self.backend_input is not None
backend_id = self.backend_input.backend_id
import_sources = {
alias: importlib.import_module(module_name)
for alias, module_name in import_sources.items()
}
f_globals = {
**import_sources,
backend_id: self.backend_input.graph_module,
}
return types.FunctionType(
self.graph_capture_output.bytecode,
f_globals,
closure=(),
)
def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
"""


@@ -1,6 +1,8 @@
import copy
import inspect
import logging
import traceback
import types
from collections import namedtuple
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
@@ -26,6 +28,7 @@ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
if TYPE_CHECKING:
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._pytree import TreeSpec
log = logging.getLogger(__name__)
@@ -446,6 +449,140 @@ def _suggest_or_raise_constraint_violation(
raise constraint_violation_error
def pytreeify(
out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Any:
"""
Given a dynamo capture output, return a callable that carries the
following information:
1. input/output pytree spec
2. input/output shuffle functions
Input shuffle functions are converters that take pytree-flattened inputs
and reorder them to the calling convention of the raw dynamo graph module.
Output shuffle functions are converters that take the outputs of the raw
dynamo graph module and convert them to the pytree format.
This function replays any side effects that happened during bytecode
execution, so it is important to check for side effects before calling it.
"""
assert out.backend_input is not None
backend_input = out.backend_input
backend = out.backend_input.graph_module
if isinstance(mod, torch.nn.Module):
args = (mod,) + args
elif inspect.ismethod(mod):
args = (mod.__self__,) + args
flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
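# Sentinel exception used to abort replay of the captured bytecode as soon as
# the dummy backend has recorded the raw graph module's inputs.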
class Yield(Exception):
pass
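# InShuffle replays the captured bytecode with a dummy backend in order to
# record how the pytree-flattened user inputs map onto the calling convention
# of the raw dynamo graph module.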
class InShuffle(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod = mod
self.num_inputs = len(flat_real_args)
self.gm_inputs = None
def forward(self, *flat_proxy_args):
args, kwargs = pytree.tree_unflatten(
[flat_proxy_args[i] for i in range(self.num_inputs)], in_spec
)
def backend_dummy(*example_inputs):
self.gm_inputs = example_inputs
raise Yield
backend_input.graph_module = backend_dummy # type: ignore[assignment]
try:
out.forward_callable()(*args, **kwargs)
except Yield:
assert self.gm_inputs is not None
return self.gm_inputs
finally:
backend_input.graph_module = backend
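# Unreachable in practice: the dummy backend always raises Yield.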
raise RuntimeError
in_shuffle_graph = torch.fx.symbolic_trace(InShuffle())
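# OutShuffle replays the captured bytecode with a dummy backend that returns
# placeholder proxies for the graph outputs, recording how those outputs are
# rearranged into the user-visible result and capturing the output pytree spec.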
class OutShuffle(torch.nn.Module):
def __init__(self):
super().__init__()
self.num_inputs = len(flat_real_args)
self.num_outputs = len(
next(iter(reversed(backend_input.graph_module.graph.nodes))).args[0]
)
self.out_spec: Optional[TreeSpec] = None
def forward(self, *flat_proxy_args):
args, kwargs = pytree.tree_unflatten(
[flat_proxy_args[i] for i in range(self.num_inputs)], in_spec
)
def backend_dummy(*example_inputs):
return [
flat_proxy_args[self.num_inputs + i]
for i in range(self.num_outputs)
]
backend_input.graph_module = backend_dummy # type: ignore[assignment]
try:
results = out.forward_callable()(*args, **kwargs)
finally:
backend_input.graph_module = backend
ret, self.out_spec = pytree.tree_flatten(results)
return ret
out_shuffle = OutShuffle()
out_shuffle_graph = torch.fx.symbolic_trace(out_shuffle)
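# Runtime wrapper: flatten the inputs, shuffle them into the raw graph's
# calling convention, run the compiled graph, then shuffle and unflatten the
# outputs according to the recorded output spec.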
def pytree_call(*args, **kwargs):
import torch.export._unlift
flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs))
if not torch.export._unlift.eq_spec(in_spec_runtime, in_spec):
raise RuntimeError(
f"Model input mismatch. Expected input spec: {in_spec}. Actual input spec: {in_spec_runtime}"
)
flat_outs = backend_input.graph_module(*in_shuffle_graph(*flat_args))
assert out_shuffle.out_spec is not None
return pytree.tree_unflatten(
out_shuffle_graph(*flat_args, *flat_outs), out_shuffle.out_spec
)
if isinstance(mod, torch.nn.Module):
compiled_mod = copy.copy(mod)
compiled_mod.forward = types.MethodType(pytree_call, compiled_mod)
if not hasattr(compiled_mod, "meta"):
compiled_mod.meta = {} # type: ignore[attr-defined]
if isinstance(compiled_mod.meta, dict) and "fake_mode" not in compiled_mod.meta:
compiled_mod.meta["fake_mode"] = out.backend_input.fake_mode
return compiled_mod
elif inspect.ismethod(mod):
return types.MethodType(pytree_call, mod.__self__)
else:
return pytree_call
def dynamo_graph_capture_for_export(
mod: Callable[..., Any],
) -> Callable[..., Any]:
def inner(*args: Any, **kwargs: Any) -> Any:
with (
get_metrics_context(),
dynamo_timed("fullgraph_capture"),
):
out = fullgraph_capture(mod, args, kwargs)
# TODO filter out side effects.
return pytreeify(out, mod, args, kwargs)
return inner
def _dynamo_graph_capture_for_export(
mod: Callable[..., Any],
*,


@@ -463,6 +463,7 @@ class OutputGraphCommon(OutputGraphGuardsState):
def __init__(
self,
output_graph_guards_state: OutputGraphGuardsState,
import_sources: Optional[dict[str, str]] = None,
shape_env: Optional[ShapeEnv] = None,
export_metadata: Optional[ExportMetaData] = None,
tracked_fakes_id_to_source: Optional[dict[int, list[Source]]] = None,
@@ -485,6 +486,7 @@ class OutputGraphCommon(OutputGraphGuardsState):
output_graph_guards_state.name_of_builtins_dict_key_in_fglobals,
)
self.import_sources = import_sources or {}
# The following fields are currently known to be used by clients.
# In particular, we need:
# - shape_env, for building guards