From 757975ad50b6c1b20091856bee16f248a9eff73c Mon Sep 17 00:00:00 2001
From: zhxchen17
Date: Wed, 22 Oct 2025 20:44:51 +0000
Subject: [PATCH] [export] Unified graph capture with fullgraph_capture.
 (#165562)

Summary:
_dynamo_graph_capture_for_export in its current form has a compatibility
issue with the main torch.compile() path, even though we reuse
fullgraph_capture as the bytecode tracer. The reason is that we flip on many
export-specific flags and even trace with a wrapped function, which causes
divergence from torch.compile() again.

This PR instead creates a new implementation of
dynamo_graph_capture_for_export that relies entirely on fullgraph_capture and
post-processing of CaptureOutput, so that we avoid inverting the phases of
the PT2 compiler stack.

This also benefits the precompile workflow, since we want a feature that
accepts only pytree inputs and ships portable Python wrappers in the package.
In other words, the code here should be shareable between export and
precompile for exporting a portable graph.

Test Plan:
===================================================================== test session starts =====================================================================
platform linux -- Python 3.12.11, pytest-7.3.2, pluggy-1.6.0
rootdir: /data/users/zhxchen17/pytorch
configfile: pytest.ini
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, subtests-0.13.1, rerunfailures-14.0, flakefinder-1.1.0, cpp-2.3.0, anyio-4.10.0
collected 9 items
Running 9 items in this shard

test/distributed/tensor/test_dtensor_export.py ........x [100%]

================================================================ 8 passed, 1 xfailed in 11.42s ================================================================

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165562
Approved by: https://github.com/tugsbayasgalan
---
 .../distributed/tensor/test_dtensor_export.py |  49 +++++--
 test/dynamo/test_aot_compile.py               |  61 ++++++++
 test/export/test_experimental.py              |  37 +++++
 torch/_dynamo/convert_frame.py                |  22 +++
 torch/_dynamo/functional_export.py            | 137 ++++++++++++++++++
 torch/_dynamo/output_graph.py                 |   2 +
 6 files changed, 297 insertions(+), 11 deletions(-)

diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py
index 4f339e43847..1f25090e576 100644
--- a/test/distributed/tensor/test_dtensor_export.py
+++ b/test/distributed/tensor/test_dtensor_export.py
@@ -6,7 +6,10 @@ import unittest
 import torch
 import torch.distributed as dist
 import torch.fx.traceback as fx_traceback
-from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+from torch._dynamo.functional_export import (
+    _dynamo_graph_capture_for_export,
+    dynamo_graph_capture_for_export,
+)
 from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
 from torch._functorch.partitioners import min_cut_rematerialization_partition
 from torch._guards import tracing, TracingContext
@@ -96,6 +99,13 @@ def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
     return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
 
 
+def graph_capture_and_aot_export_joint_with_descriptors_v2(model, inputs):
+    gm = dynamo_graph_capture_for_export(model)(inputs)
+    fake_mode = gm.meta.get("fake_mode", None)
+    with tracing(TracingContext(fake_mode)):
+        return aot_export_joint_with_descriptors_alone(gm, inputs)
+
+
 def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
     with torch._dynamo.config.patch(install_free_tensors=True):
         # TODO: switch to use the official graph_capture API once it is ready
@@ -288,6 +298,7 @@ class DTensorExportTest(TestCase):
     @parametrize(
         "export_fn",
         [
+            graph_capture_and_aot_export_joint_with_descriptors_v2,
             graph_capture_and_aot_export_joint_with_descriptors,
             aot_export_joint_with_descriptors_alone,
         ],
@@ -307,7 +318,21 @@ class DTensorExportTest(TestCase):
     def test_annotate_aot_export_joint_with_descriptors_alone(self):
         self._run_test(aot_export_joint_with_descriptors_alone, True)
 
-    def test_dynamic_shapes(self):
+    @parametrize(
+        "export_fn_with_answer",
+        [
+            (
+                graph_capture_and_aot_export_joint_with_descriptors_v2,
+                "[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
+            ),
+            (
+                graph_capture_and_aot_export_joint_with_descriptors,
+                "[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
+            ),
+        ],
+    )
+    def test_dynamic_shapes(self, export_fn_with_answer):
+        export_fn, answer = export_fn_with_answer
         dp_degree = 2
         tp_degree = self.world_size // dp_degree
 
@@ -331,7 +356,7 @@ class DTensorExportTest(TestCase):
         inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
         torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
 
-        joint_gm = graph_capture_and_aot_export_joint_with_descriptors(tp_model, inputs)
+        joint_gm = export_fn(tp_model, inputs)
 
         res = []
         for node in joint_gm.graph.nodes:
@@ -341,12 +366,16 @@ class DTensorExportTest(TestCase):
             if isinstance(fake_val, torch._subclasses.fake_tensor.FakeTensor):
                 res.append(list(fake_val.shape))
 
-        self.assertExpectedInline(
-            str(res),
-            """[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
-        )
+        self.assertEqual(str(res), answer)
 
-    def test_einsum_dtensor_export(self):
+    @parametrize(
+        "export_fn",
+        [
+            dynamo_graph_capture_for_export,
+            _dynamo_graph_capture_for_export,
+        ],
+    )
+    def test_einsum_dtensor_export(self, export_fn):
         """Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
         world_size = 4
         # Create device mesh
@@ -366,9 +395,7 @@ class DTensorExportTest(TestCase):
         output = model(x_dtensor, y_dtensor, z_dtensor)
         with torch._dynamo.config.patch(install_free_tensors=True):
             # TODO: switch to use the official graph_capture API once it is ready
-            gm = _dynamo_graph_capture_for_export(model)(
-                x_dtensor, y_dtensor, z_dtensor
-            )
+            gm = export_fn(model)(x_dtensor, y_dtensor, z_dtensor)
         output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
         self.assertEqual(output, output_gm)
 
diff --git a/test/dynamo/test_aot_compile.py b/test/dynamo/test_aot_compile.py
index d543fe76d65..12efbdab94f 100644
--- a/test/dynamo/test_aot_compile.py
+++ b/test/dynamo/test_aot_compile.py
@@ -471,6 +471,67 @@ from user code:
         assert hasattr(backend_result.compiled_fn, "serialize")
         self.assertIsNotNone(backend_result.compiled_fn.serialize)
 
+    def test_fullgraph_capture_with_pytree_module(self):
+        from torch._dynamo.functional_export import dynamo_graph_capture_for_export
+
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(3, 3)
+                self.linear1 = torch.nn.Linear(3, 3)
+                self.linear2 = torch.nn.Linear(3, 3)
+                self.linear3 = torch.nn.Linear(3, 3)
+
+            def forward(self, x):
+                return {
+                    "y": self.linear2(x[2] + 1),
+                    "z": self.linear3(x[1] - 1),
+                    "w": self.linear(x[0]["b"] + 2),
+                    "v": self.linear1(x[0]["a"] - 2),
+                }
+
+        mod = Module()
+        compiled_mod = dynamo_graph_capture_for_export(mod)(
+            (
+                {"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
+                torch.randn(3, 3),
+                torch.randn(3, 3),
+            )
+        )
+
+        inputs = (
+            {"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
+            torch.randn(3, 3),
+            torch.randn(3, 3),
+        )
+        self.assertEqual(compiled_mod(inputs), mod(inputs))
+
+    def test_fullgraph_capture_with_pytree_func(self):
+        from torch._dynamo.functional_export import dynamo_graph_capture_for_export
+
+        def foo(x):
+            return {
+                "y": x[2] + 1,
+                "z": x[1] - 1,
+                "w": x[0]["b"] + 2,
+                "v": x[0]["a"] - 2,
+            }
+
+        compiled_foo = dynamo_graph_capture_for_export(foo)(
+            (
+                {"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
+                torch.randn(2, 3),
+                torch.randn(3, 4),
+            )
+        )
+
+        inputs = (
+            {"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
+            torch.randn(2, 3),
+            torch.randn(3, 4),
+        )
+        self.assertEqual(compiled_foo(inputs), foo(inputs))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py
index 6e9379be092..2f4a370ac79 100644
--- a/test/export/test_experimental.py
+++ b/test/export/test_experimental.py
@@ -402,6 +402,43 @@ def forward(self, x):
 
         self.assertEqual(res_export, res_eager)
 
+    def test_dynamo_graph_capture(self):
+        from torch._dynamo.functional_export import dynamo_graph_capture_for_export
+
+        class Foo(torch.nn.Module):
+            def forward(self, dct, lst, bleh):
+                x = dct["a"] * lst[1][0]
+                y = dct["b"] * lst[0]
+                out_dict = {}
+
+                # Mutate and get a new entry in there
+                lst_copy = lst.copy()
+                lst_copy.append(lst[0])
+                out_dict["a"] = x
+                out_dict["b"] = y
+                return (
+                    dct["a"],
+                    out_dict["b"],
+                    bleh,
+                    lst_copy[-1],
+                    out_dict["a"],
+                    [5, 6],
+                )
+
+        foo = Foo()
+
+        def make_inputs():
+            return (
+                {"a": torch.randn(2, 3), "b": torch.randn(2, 3)},
+                [torch.randn(2, 3), (torch.randn(2, 3),)],
+                torch.randn(2, 3),
+            )
+
+        trace_inputs = make_inputs()
+        gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
+        test_inputs = make_inputs()
+        self.assertEqual(gm(*test_inputs), foo(*test_inputs))
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index e1b4e051672..679705beec5 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -886,6 +886,7 @@ class DynamoOutput:
         return GraphCaptureOutput(
             OutputGraphCommon(
                 output_graph.dump_guards_state(),
+                output_graph.import_sources,
                 output_graph.shape_env,
                 output_graph.export_metadata,
                 output_graph.tracked_fakes_id_to_source,
@@ -960,6 +961,27 @@ class CaptureOutput:
     # BackendInput can be None when dynamo didn't compile any graph (no tensor op)
     backend_input: Optional[BackendInput]
 
+    def forward_callable(self) -> Callable[..., Any]:
+        import importlib
+
+        # TODO code sharing
+        import_sources = self.graph_capture_output.output_graph.import_sources
+        assert self.backend_input is not None
+        backend_id = self.backend_input.backend_id
+        import_sources = {
+            alias: importlib.import_module(module_name)
+            for alias, module_name in import_sources.items()
+        }
+        f_globals = {
+            **import_sources,
+            backend_id: self.backend_input.graph_module,
+        }
+        return types.FunctionType(
+            self.graph_capture_output.bytecode,
+            f_globals,
+            closure=(),
+        )
+
 
 def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
     """
diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py
index 4af379cfe39..c2b2f3e43a5 100644
--- a/torch/_dynamo/functional_export.py
+++ b/torch/_dynamo/functional_export.py
@@ -1,6 +1,8 @@
+import copy
 import inspect
 import logging
 import traceback
+import types
 from collections import namedtuple
 from typing import Any, Callable, Optional, TYPE_CHECKING, Union
@@ -26,6 +28,7 @@ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 
 if TYPE_CHECKING:
     from torch._subclasses.fake_tensor import FakeTensorMode
+    from torch.utils._pytree import TreeSpec
 
 
 log = logging.getLogger(__name__)
@@ -446,6 +449,140 @@ def _suggest_or_raise_constraint_violation(
         raise constraint_violation_error
 
 
+def pytreeify(
+    out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> Any:
+    """
+    Given a dynamo capture output, return a callable graph module that
+    contains the following information:
+    1. input/output pytree spec
+    2. input/output shuffle functions
+    Input shuffle functions are the converters taking pytree flattened inputs
+    and reordering them to the calling convention of the dynamo raw graph module.
+    Output shuffle functions are the converters taking the outputs of the
+    dynamo raw graph module and converting them to the pytree format.
+
+    This function will replay any side effects that happened during bytecode
+    execution, so it is important to check for side effects before calling
+    this function.
+    """
+    assert out.backend_input is not None
+    backend_input = out.backend_input
+    backend = out.backend_input.graph_module
+
+    if isinstance(mod, torch.nn.Module):
+        args = (mod,) + args
+    elif inspect.ismethod(mod):
+        args = (mod.__self__,) + args
+
+    flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
+
+    class Yield(Exception):
+        pass
+
+    class InShuffle(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.mod = mod
+            self.num_inputs = len(flat_real_args)
+            self.gm_inputs = None
+
+        def forward(self, *flat_proxy_args):
+            args, kwargs = pytree.tree_unflatten(
+                [flat_proxy_args[i] for i in range(self.num_inputs)], in_spec
+            )
+
+            def backend_dummy(*example_inputs):
+                self.gm_inputs = example_inputs
+                raise Yield
+
+            backend_input.graph_module = backend_dummy  # type: ignore[assignment]
+            try:
+                out.forward_callable()(*args, **kwargs)
+            except Yield:
+                assert self.gm_inputs is not None
+                return self.gm_inputs
+            finally:
+                backend_input.graph_module = backend
+            raise RuntimeError
+
+    in_shuffle_graph = torch.fx.symbolic_trace(InShuffle())
+
+    class OutShuffle(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.num_inputs = len(flat_real_args)
+            self.num_outputs = len(
+                next(iter(reversed(backend_input.graph_module.graph.nodes))).args[0]
+            )
+            self.out_spec: Optional[TreeSpec] = None
+
+        def forward(self, *flat_proxy_args):
+            args, kwargs = pytree.tree_unflatten(
+                [flat_proxy_args[i] for i in range(self.num_inputs)], in_spec
+            )
+
+            def backend_dummy(*example_inputs):
+                return [
+                    flat_proxy_args[self.num_inputs + i]
+                    for i in range(self.num_outputs)
+                ]
+
+            backend_input.graph_module = backend_dummy  # type: ignore[assignment]
+            try:
+                results = out.forward_callable()(*args, **kwargs)
+            finally:
+                backend_input.graph_module = backend
+            ret, self.out_spec = pytree.tree_flatten(results)
+            return ret
+
+    out_shuffle = OutShuffle()
+    out_shuffle_graph = torch.fx.symbolic_trace(out_shuffle)
+
+    def pytree_call(*args, **kwargs):
+        import torch.export._unlift
+
+        flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs))
+        if not torch.export._unlift.eq_spec(in_spec_runtime, in_spec):
+            raise RuntimeError(
+                f"Model input mismatch. Expected input spec: {in_spec}. Actual input spec: {in_spec_runtime}"
+            )
+        flat_outs = backend_input.graph_module(*in_shuffle_graph(*flat_args))
+        assert out_shuffle.out_spec is not None
+        return pytree.tree_unflatten(
+            out_shuffle_graph(*flat_args, *flat_outs), out_shuffle.out_spec
+        )
+
+    if isinstance(mod, torch.nn.Module):
+        compiled_mod = copy.copy(mod)
+        compiled_mod.forward = types.MethodType(pytree_call, compiled_mod)
+        if not hasattr(compiled_mod, "meta"):
+            compiled_mod.meta = {}  # type: ignore[attr-defined]
+        if isinstance(compiled_mod.meta, dict) and "fake_mode" not in compiled_mod.meta:
+            compiled_mod.meta["fake_mode"] = out.backend_input.fake_mode
+        return compiled_mod
+    elif inspect.ismethod(mod):
+        return types.MethodType(pytree_call, mod.__self__)
+    else:
+        return pytree_call
+
+
+def dynamo_graph_capture_for_export(
+    mod: Callable[..., Any],
+) -> Callable[..., Any]:
+    def inner(*args: Any, **kwargs: Any) -> Any:
+        with (
+            get_metrics_context(),
+            dynamo_timed("fullgraph_capture"),
+        ):
+            out = fullgraph_capture(mod, args, kwargs)
+
+        # TODO filter out side effects.
+
+        return pytreeify(out, mod, args, kwargs)
+
+    return inner
+
+
 def _dynamo_graph_capture_for_export(
     mod: Callable[..., Any],
     *,
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index f39d80f89b4..75b10aa4b6d 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -463,6 +463,7 @@ class OutputGraphCommon(OutputGraphGuardsState):
     def __init__(
         self,
         output_graph_guards_state: OutputGraphGuardsState,
+        import_sources: Optional[dict[str, str]] = None,
         shape_env: Optional[ShapeEnv] = None,
         export_metadata: Optional[ExportMetaData] = None,
         tracked_fakes_id_to_source: Optional[dict[int, list[Source]]] = None,
@@ -485,6 +486,7 @@ class OutputGraphCommon(OutputGraphGuardsState):
             output_graph_guards_state.name_of_builtins_dict_key_in_fglobals,
        )
 
+        self.import_sources = import_sources or {}
         # The following fields are currently known to be used by clients.
         # In particular, we need:
         # - shape_env, for building guards
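
Usage sketch (not part of the patch): based on the tests added above, this is roughly how the new dynamo_graph_capture_for_export entry point is called. The toy module, shapes, and allclose check below are illustrative assumptions, not code from this PR.

    import torch
    from torch._dynamo.functional_export import dynamo_graph_capture_for_export

    class ToyModel(torch.nn.Module):  # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            # Inputs and outputs are arbitrary pytrees (here: dicts of tensors).
            return {"y": self.linear(x["a"] + x["b"])}

    mod = ToyModel()
    example = {"a": torch.randn(3, 3), "b": torch.randn(3, 3)}
    # Capture once with example inputs; the result is a callable copy of the
    # module that accepts and returns the same pytree structure as the original.
    gm = dynamo_graph_capture_for_export(mod)(example)
    assert torch.allclose(gm(example)["y"], mod(example)["y"])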