[export] Update dynamo_graph_capture_for_export to return GraphModule. (#166091)

Make dynamo_graph_capture_for_export return a more compatible GraphModule object that is closer to the original behavior of dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166091
Approved by: https://github.com/tugsbayasgalan
Zhengxu Chen 2025-10-28 04:23:28 +00:00 committed by PyTorch MergeBot
parent a77f5d9a00
commit f93ea7dab1
10 changed files with 379 additions and 85 deletions
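Usage sketch of the new behavior, adapted loosely from the tests added in this commit (the `Foo` module, tensor shapes, and the `assert_close` check below are illustrative, not part of the diff): the capture entry point now hands back a `torch.fx.GraphModule` that can be called directly with the original, pytree-structured inputs.

```python
import torch
from torch._dynamo.functional_export import dynamo_graph_capture_for_export


class Foo(torch.nn.Module):  # illustrative module, not taken from the diff
    def forward(self, x, y):
        return x + y


foo = Foo()
inputs = (torch.randn(2, 3), torch.randn(2, 3))

# Capture returns a GraphModule rather than a wrapped callable.
gm = dynamo_graph_capture_for_export(foo)(*inputs)
assert isinstance(gm, torch.fx.GraphModule)

print(gm.code)                    # flat forward that routes through the pytree shuffle graphs
print(gm.meta["fake_mode"])       # fake mode recorded during capture
print(gm._in_spec, gm._out_spec)  # input/output pytree specs attached by the new codegen

# Calling the captured GraphModule should match eager behavior on fresh inputs.
torch.testing.assert_close(gm(*inputs), foo(*inputs))
```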

View File

@@ -442,13 +442,22 @@ class DTensorExportTest(TestCase):
         # Run model to verify it works
         output = model(*inputs)
-        with torch._dynamo.config.patch(install_free_tensors=True):
+        with torch._dynamo.config.patch(
+            install_free_tensors=(export_fn is _dynamo_graph_capture_for_export)
+        ):
             # TODO: switch to use the official graph_capture API once it is ready
             gm = export_fn(model)(*inputs)
             output_gm = gm(*inputs)
             self.assertEqual(output, output_gm)

-    def test_flex_attention_dtensor_export(self):
+    @parametrize(
+        "export_fn",
+        [
+            graph_capture_and_aot_export_joint_with_descriptors_v2,
+            graph_capture_and_aot_export_joint_with_descriptors,
+        ],
+    )
+    def test_flex_attention_dtensor_export(self, export_fn):
         device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
         model = FlexAttentionModel(self.device_type)

@@ -485,9 +494,7 @@ class DTensorExportTest(TestCase):
         flex_kwargs = {"block_mask": block_mask}

-        joint_gm = graph_capture_and_aot_export_joint_with_descriptors(
-            tp_model, inputs, flex_kwargs
-        )
+        joint_gm = export_fn(tp_model, inputs, flex_kwargs)

         self.assertTrue(
             _count_op(joint_gm, torch.ops.higher_order.flex_attention),

View File

@@ -3,16 +3,19 @@
 import copy
 import types
 import unittest
+from dataclasses import dataclass
 from typing import Dict, List, Tuple

 import torch
 import torch._dynamo
+from torch._dynamo.functional_export import dynamo_graph_capture_for_export
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._functorch.aot_autograd import aot_export_module
 from torch.export import export
 from torch.export.experimental import _export_forward_backward, _sticky_export
 from torch.export.graph_signature import OutputKind
 from torch.testing import FileCheck
+from torch.testing._internal.common_utils import TEST_CUDA


 @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")

@@ -403,8 +406,6 @@ def forward(self, x):
         self.assertEqual(res_export, res_eager)

     def test_dynamo_graph_capture(self):
-        from torch._dynamo.functional_export import dynamo_graph_capture_for_export
-
         class Foo(torch.nn.Module):
             def forward(self, dct, lst, bleh):
                 x = dct["a"] * lst[1][0]

@@ -439,6 +440,151 @@ def forward(self, x):
         test_inputs = make_inputs()
         self.assertEqual(gm(*test_inputs), foo(*test_inputs))

+    def test_dynamo_graph_capture_custom_pytree_type(self):
+        import torch.utils._pytree as pytree
+
+        @dataclass
+        class Bar:
+            x: torch.Tensor
+            y: torch.Tensor
+
+        class Foo(torch.nn.Module):
+            def forward(self, bar: Bar):
+                return bar.x + bar.y
+
+        foo = Foo()
+
+        def make_inputs():
+            return (Bar(torch.randn(2, 3), torch.randn(2, 3)),)
+
+        pytree.register_dataclass(Bar)
+        try:
+            trace_inputs = make_inputs()
+            gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
+            test_inputs = make_inputs()
+            self.assertExpectedInline(
+                gm._in_shuffle_graph.code.strip("\r\n "),
+                """\
+def forward(self, arg0_1, arg1_1, arg2_1):
+    return (arg1_1, arg2_1)""",
+            )
+            self.assertExpectedInline(
+                gm.code.strip("\r\n "),
+                """\
+def forward(self, args_0):
+    _tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0,))
+    L_bar_x , L_bar_y , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2)
+    l_bar_x = L_bar_x
+    l_bar_y = L_bar_y
+    add = l_bar_x + l_bar_y; l_bar_x = l_bar_y = None
+    return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, add), self._out_spec)""",
+            )
+            self.assertEqual(gm(*test_inputs), foo(*test_inputs))
+        finally:
+            pytree._deregister_pytree_node(Bar)
+
+    def test_dynamo_graph_capture_closure(self):
+        from torch.export import Dim
+
+        N = 3
+        outer = torch.randn(10, 32)
+
+        class MyModel(torch.nn.Module):
+            def forward(self, x):
+                z = x + outer
+                y = z[:-1, :]  # [s0 - 1, 32]
+                stacked = torch.stack([y] * N, dim=0)  # [N * (s0 - 1), 32]
+                reshaped = stacked.reshape(-1, N, 32)  # [(s0 - 1), N, 32]
+                return reshaped
+
+        inps = (torch.randn(10, 32),)
+        ep = dynamo_graph_capture_for_export(MyModel())(*inps)
+        self.assertExpectedInline(
+            ep._in_shuffle_graph.code.strip("\r\n "),
+            """\
+def forward(self, arg0_1, arg1_1):
+    _tensor_constant0 = self._tensor_constant0
+    return (arg1_1, _tensor_constant0)""",
+        )
+        self.assertExpectedInline(
+            ep.code.strip("\r\n "),
+            """\
+def forward(self, args_0):
+    _tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
+    L_x_ , L_outer_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
+    l_x_ = L_x_
+    l_outer_ = L_outer_
+    z = l_x_ + l_outer_; l_x_ = l_outer_ = None
+    y = z[(slice(None, -1, None), slice(None, None, None))]; z = None
+    stacked = torch.stack([y, y, y], dim = 0); y = None
+    reshaped = stacked.reshape(-1, 3, 32); stacked = None
+    return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, reshaped), self._out_spec)""",
+        )
+        self.assertEqual(ep(*inps), MyModel()(*inps))
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
+        class DummyOp(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, scalar):
+                ctx.save_for_backward(x)
+                return x + scalar
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                return grad_out, None
+
+        def mock_fw_compute(x):
+            with fx_traceback.annotate({"compute": 0}):
+                return DummyOp.apply(x, 10)
+
+        def mock_bw_comm(x):
+            with fx_traceback.annotate({"comm": 0}):
+                return DummyOp.apply(x, 20)
+
+        def mock_bw_compute(x):
+            return DummyOp.apply(x, 30)
+
+        class Model(torch.nn.Module):
+            def forward(self, fw_in, bw_in):
+                fw_out = mock_fw_compute(fw_in)
+                # bw_in blocks bw_out
+                bw_in = mock_bw_comm(bw_in)
+                bw_out = mock_bw_compute(bw_in)
+                return fw_out, bw_out
+
+        def input_fn():
+            inputs = (torch.rand(2, 128, device="cuda", requires_grad=True),)
+            grad_ins = (torch.rand(2, 128, device="cuda"),)
+            return (
+                *inputs,
+                *grad_ins,
+            )
+
+        with torch.device("meta"):
+            model = Model()
+
+        import torch.fx.traceback as fx_traceback
+
+        with fx_traceback.preserve_node_meta():
+            gm = dynamo_graph_capture_for_export(model)(*input_fn())
+
+        """
+def forward(self, args_0, args_1):
+    _tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0, args_1,))
+    L_fw_in_ , L_bw_in_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2)
+    l_fw_in_ = L_fw_in_
+    l_bw_in_ = L_bw_in_
+    fwd_body_0 = self.fwd_body_0
+    bwd_body_0 = self.bwd_body_0
+    fw_out = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_fw_in_, args_tensor_mask = [True, False], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_fw_in_ = None
+    bw_in = l_bw_in_ + 20; l_bw_in_ = None
+    bw_out = bw_in + 30; bw_in = None
+    return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, fw_out, bw_out), self._out_spec)
+        """
+        test_inputs = input_fn()
+        self.assertEqual(gm(*test_inputs), model(*test_inputs))
+

 if __name__ == "__main__":
     run_tests()

View File

@@ -896,6 +896,7 @@ class DynamoOutput:
             output_graph.import_sources,
             output_graph.traced_code,
             self.bytecode,
+            self.tracer_output.closure,
         )

@@ -927,6 +928,7 @@ class GraphCaptureOutput:
     import_sources: dict[str, str]
     traced_code: list[CodeType]
    bytecode: CodeType
+    closure: Optional[tuple[Any, ...]]

     def build_guards(
         self,

@@ -981,7 +983,7 @@ class CaptureOutput:
         return types.FunctionType(
             self.graph_capture_output.bytecode,
             f_globals,
-            closure=(),
+            closure=self.graph_capture_output.closure,
         )

View File

@@ -1211,7 +1211,9 @@ class _NullDecorator(contextlib.nullcontext):  # type: ignore[type-arg]
 # Make dynamo graph to have same input/output spec as user code
 def argument_names(
-    f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any]
+    f_sig: inspect.Signature,
+    args: Union[list[Any], tuple[Any, ...]],
+    kwargs: dict[str, Any],
 ) -> list[str]:
     def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
         # Get a list of Parameter objects from the Signature object

View File

@@ -1,10 +1,9 @@
-import copy
 import inspect
 import logging
 import traceback
-import types
 from collections import namedtuple
 from collections.abc import Callable
+from dataclasses import dataclass
 from typing import Any, Optional, TYPE_CHECKING, Union

 import sympy

@@ -19,17 +18,19 @@ from torch._dynamo.utils import dynamo_timed, get_metrics_context
 from torch._export.utils import _compiling_state_context
 from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
 from torch.fx import Node
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import (
     ConstraintViolationError,
     DimDynamic,
+    ShapeEnv,
     StatelessSymbolicContext,
 )
-from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from torch.fx.graph import _ExportCodeGen, _PyTreeCodeGen, _PyTreeInfo
+from torch.utils._pytree import TreeSpec


 if TYPE_CHECKING:
     from torch._subclasses.fake_tensor import FakeTensorMode
-    from torch.utils._pytree import TreeSpec


 log = logging.getLogger(__name__)

@@ -448,9 +449,20 @@ def _suggest_or_raise_constraint_violation(
         raise constraint_violation_error


+@dataclass(frozen=True)
+class PyTreeifyOutput:
+    graph_module: torch.fx.GraphModule
+    in_spec: TreeSpec
+    in_shuffle_graph: torch.fx.GraphModule
+    num_flat_args: int
+    out_spec: TreeSpec
+    out_shuffle_graph: torch.fx.GraphModule
+    root: Optional[torch.nn.Module] = None
+
+
 def pytreeify(
     out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
-) -> Any:
+) -> PyTreeifyOutput:
     """
     Given a dynamo capture output, return a callable graph module that
     contain the following information:

@@ -468,10 +480,13 @@ def pytreeify(
     backend_input = out.backend_input
     backend = out.backend_input.graph_module

+    root = None
     if isinstance(mod, torch.nn.Module):
         args = (mod,) + args
+        root = mod
     elif inspect.ismethod(mod):
         args = (mod.__self__,) + args
+        root = mod.__self__

     flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))

@@ -504,15 +519,21 @@ def pytreeify(
             backend_input.graph_module = backend
             raise RuntimeError

-    in_shuffle_graph = torch.fx.symbolic_trace(InShuffle())
+    fake_mode = torch._dynamo.utils.detect_fake_mode(flat_real_args)
+    if fake_mode and fake_mode.shape_env is None:
+        fake_mode.shape_env = ShapeEnv()
+    in_shuffle_graph = make_fx(
+        InShuffle(), tracing_mode="symbolic", proxy_module_inputs=True
+    )(*flat_real_args)
+
+    output_node = next(iter(reversed(backend_input.graph_module.graph.nodes)))

     class OutShuffle(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.num_inputs = len(flat_real_args)
-            self.num_outputs = len(
-                next(iter(reversed(backend_input.graph_module.graph.nodes))).args[0]
-            )
+            self.num_outputs = len(output_node.args[0])
             self.out_spec: Optional[TreeSpec] = None

         def forward(self, *flat_proxy_args):

@@ -535,49 +556,101 @@ def pytreeify(
             return ret

     out_shuffle = OutShuffle()
-    out_shuffle_graph = torch.fx.symbolic_trace(out_shuffle)
+    flat_out_shuffle_args = [
+        *flat_real_args,
+        *pytree.tree_map_only(
+            torch.fx.Node,
+            lambda x: fake_mode.from_tensor(x.meta["example_value"])
+            if fake_mode
+            else x.meta["example_value"],
+            output_node.args[0],
+        ),
+    ]
+    fake_mode = torch._dynamo.utils.detect_fake_mode(flat_out_shuffle_args)
+    if fake_mode and fake_mode.shape_env is None:
+        fake_mode.shape_env = ShapeEnv()
+    out_shuffle_graph = make_fx(
+        out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True
+    )(*flat_out_shuffle_args)

-    def pytree_call(*args, **kwargs):
-        import torch.export._unlift
-
-        flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs))
-        if not torch.export._unlift.eq_spec(in_spec_runtime, in_spec):
-            raise RuntimeError(
-                f"Model input mismatch. Expected input spec: {in_spec}. Actual input spec: {in_spec_runtime}"
-            )
-        flat_outs = backend_input.graph_module(*in_shuffle_graph(*flat_args))
-        assert out_shuffle.out_spec is not None
-        return pytree.tree_unflatten(
-            out_shuffle_graph(*flat_args, *flat_outs), out_shuffle.out_spec
-        )
+    assert out_shuffle.out_spec is not None
+    return PyTreeifyOutput(
+        backend_input.graph_module,
+        in_spec,
+        in_shuffle_graph,
+        len(flat_real_args),
+        out_shuffle.out_spec,
+        out_shuffle_graph,
+        root=root,  # type: ignore[arg-type]
+    )

-    if isinstance(mod, torch.nn.Module):
-        compiled_mod = copy.copy(mod)
-        compiled_mod.forward = types.MethodType(pytree_call, compiled_mod)
-        if not hasattr(compiled_mod, "meta"):
-            compiled_mod.meta = {}  # type: ignore[attr-defined]
-        if isinstance(compiled_mod.meta, dict) and "fake_mode" not in compiled_mod.meta:
-            compiled_mod.meta["fake_mode"] = out.backend_input.fake_mode
-        return compiled_mod
-    elif inspect.ismethod(mod):
-        return types.MethodType(pytree_call, mod.__self__)
-    else:
-        return pytree_call
+
+def normalize_graph_module(gm):
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            node.meta["val"] = node.meta["example_value"]


 def dynamo_graph_capture_for_export(
     mod: Callable[..., Any],
+    constraints: Optional[list[Constraint]] = None,
 ) -> Callable[..., Any]:
     def inner(*args: Any, **kwargs: Any) -> Any:
+        assert not torch._dynamo.config.install_free_tensors
         with (
             get_metrics_context(),
             dynamo_timed("fullgraph_capture"),
         ):
-            out = fullgraph_capture(mod, args, kwargs)
+            out = fullgraph_capture(
+                mod,
+                args,
+                kwargs,
+                constraints=constraints,
+            )
         # TODO filter out side effects.
-        return pytreeify(out, mod, args, kwargs)
+        pyt = pytreeify(out, mod, args, kwargs)
+        graph_module = pyt.graph_module
+        tree_leaf_names = [
+            graph_module.graph._graph_namespace.create_name(f"_tree_leaf_{i}", None)
+            for i in range(pyt.num_flat_args)
+        ]
+        graph_module.graph._codegen = _ExportCodeGen(
+            _PyTreeInfo(
+                # TODO we should be able to use the names from dynamo graph directly.
+                argument_names(inspect.signature(mod), args, kwargs),
+                pyt.in_spec,
+                pyt.out_spec,
+            ),
+            pyt.in_shuffle_graph,
+            pyt.out_shuffle_graph,
+            tree_leaf_names,
+            pyt.root,
+        )  # type: ignore[attr-defined]
+        normalize_graph_module(graph_module)
+        if pyt.root is not None:
+            graph_module._parameters = pyt.root._parameters.copy()
+            graph_module._buffers = pyt.root._buffers.copy()
+            assert all(not hasattr(graph_module, m) for m in pyt.root._modules)
+            graph_module._modules.update(pyt.root._modules)
+            graph_module._non_persistent_buffers_set = (
+                pyt.root._non_persistent_buffers_set.copy()
+            )
+        graph_module._in_spec = pyt.in_spec
+        graph_module._out_spec = pyt.out_spec
+        assert not hasattr(graph_module, "_in_shuffle_graph")
+        assert not hasattr(graph_module, "_out_shuffle_graph")
+        graph_module._in_shuffle_graph = pyt.in_shuffle_graph
+        graph_module._out_shuffle_graph = pyt.out_shuffle_graph
+        delattr(graph_module, "_param_name_to_source")
+        graph_module.recompile()
+        graph_module.meta["module_call_specs"] = (
+            out.graph_capture_output.output_graph.export_metadata.module_call_spec
+        )
+        assert out.backend_input is not None
+        graph_module.meta["fake_mode"] = out.backend_input.fake_mode  # type: ignore[attr-defined]
+        return graph_module

     return inner

View File

@@ -2747,12 +2747,14 @@ class DynamoTracerOutput:
     error_on_graph_break: bool
     is_tracing_resume_prologue: bool
     output_graph: Optional[OutputGraph]
+    closure: Optional[tuple[Any, ...]]

     def __init__(
         self, tracer: "InstructionTranslatorBase", error: Optional[Any] = None
     ) -> None:
         self.error_on_graph_break = tracer.error_on_graph_break
         self.is_tracing_resume_prologue = tracer.is_tracing_resume_prologue
+        self.closure = tracer.closure
         if error:
             self.output_graph = None
         else:

View File

@@ -4217,6 +4217,7 @@ class InstructionTranslatorBase(
         self.f_builtins: dict[str, Any] = f_builtins
         self.code_options: dict[str, Any] = code_options
         self.f_code: types.CodeType = f_code
+        self.closure = closure

         # Execution record for replaying errors
         if closure is not None and config.replay_record_enabled:

View File

@@ -97,7 +97,7 @@ from torch.fx.experimental.symbolic_shapes import (
     GuardOnDataDependentSymNode,
     ShapeEnv,
 )
-from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from torch.fx.graph import _PyTreeInfo
 from torch.utils._pytree import TreeSpec
 from torch.utils._sympy.value_ranges import ValueRangeError

@@ -1537,12 +1537,10 @@ def _strict_export(
     orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]

-    gm_torch_level.graph._codegen = _PyTreeCodeGen(
-        _PyTreeInfo(
-            orig_arg_names,
-            gm_torch_level._in_spec,
-            out_spec,
-        )
+    gm_torch_level.graph._codegen.pytree_info = _PyTreeInfo(
+        orig_arg_names,
+        gm_torch_level._in_spec,
+        out_spec,
     )
     gm_torch_level.recompile()

View File

@@ -1489,12 +1489,20 @@ def wrap_key(
     @functools.wraps(f)
     def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R:
+        nonlocal tensors
         flat_proxies, _proxies_spec = pytree.tree_flatten(proxies)
         assert len(flat_proxies) == len(flat_tensors)
         with disable_proxy_modes_tracing() as m:
             assert isinstance(m, ProxyTorchDispatchMode)
             track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)

+        if getattr(tracer, "proxy_module_inputs", False):
+            tensors = [  # type: ignore[assignment, var-annotated]
+                p if isinstance(t, torch.nn.Module) else t
+                for t, p in zip(tensors, proxies)  # type: ignore[arg-type]
+            ]
+
         def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
             return get_proxy_slot(t, tracer, t, lambda x: x.proxy)  # type: ignore[attr-defined]

@@ -2208,6 +2216,7 @@ class _MakefxTracer:
         _error_on_data_dependent_ops: bool,
         record_stack_traces: bool = False,
         parent_tracer: Optional[_MakefxTracer] = None,
+        proxy_module_inputs: bool = False,
     ) -> None:
         # Configurations that are used to initialize the context managers and their states.
         # Should not modify them during tracing.

@@ -2240,6 +2249,7 @@ class _MakefxTracer:
         )
         self.record_stack_traces = record_stack_traces
         self.parent_tracer: Optional[_MakefxTracer] = parent_tracer
+        self.proxy_module_inputs = proxy_module_inputs

     def _checkpoint_modes(self) -> list[Any]:
         return [

@@ -2349,6 +2359,7 @@ class _MakefxTracer:
             self.python_dispatcher_mode = enable_python_dispatcher()

         self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer)
+        fx_tracer.proxy_module_inputs = self.proxy_module_inputs  # type: ignore[union-attr]

     @contextmanager
     def _init_modes_from_parent(

@@ -2551,6 +2562,7 @@ def make_fx(
     _allow_fake_constant: bool = False,
     _error_on_data_dependent_ops: bool = True,
     record_stack_traces: bool = False,
+    proxy_module_inputs: bool = False,
 ) -> Callable[..., GraphModule]:
     """
     Given a function f, return a new function which when executed with valid

@@ -2574,6 +2586,7 @@ def make_fx(
         _error_on_data_dependent_ops,
         record_stack_traces=record_stack_traces
        or config.trace.provenance_tracking_level == 1,
+        proxy_module_inputs=proxy_module_inputs,
     )

     @functools.wraps(f)

View File

@@ -930,6 +930,42 @@ class _PyTreeCodeGen(CodeGen):
         else:
             return "\n    " + "".join(x + "; " for x in has_annotation) + "\n"

+    def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
+        # when kwargs is present, in_spec is tuple(args, kwargs)
+        has_args_kwargs_tuple = (
+            self.pytree_info.in_spec.type is tuple
+            and self.pytree_info.in_spec.num_children == 2
+            and self.pytree_info.in_spec.children_specs[0].type is tuple
+            and self.pytree_info.in_spec.children_specs[1].type is dict
+        )
+        fn_kwargs = "{}"
+        fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
+        if has_args_kwargs_tuple:
+            count_args = self.pytree_info.in_spec.children_specs[0].num_children
+            fn_args = self.pytree_info.orig_args[:count_args]
+            fn_kwargs = (
+                "{"
+                + ", ".join(
+                    f"'{k}':{v}"
+                    for k, v in zip(
+                        self.pytree_info.in_spec.children_specs[1].context,
+                        self.pytree_info.orig_args[count_args:],
+                    )
+                )
+                + "}"
+            )
+            fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
+
+        # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
+        # we need to split it to two lines:
+        # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
+        # one for code: `var1, var2, = function_call()`
+        without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
+        bindings = self._format_annotations(free_vars, expanded_def)
+        bindings += f"""
+    {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
+        return bindings
+
     def gen_fn_def(
         self, free_vars, maybe_return_annotation, *, expanded_def: bool = False
     ):

@@ -962,39 +998,7 @@ class _PyTreeCodeGen(CodeGen):
             )

         if len(free_vars) > 0:  # pytree has placeholders in it
-            # when kwargs is present, in_spec is tuple(args, kwargs)
-            has_args_kwargs_tuple = (
-                self.pytree_info.in_spec.type is tuple
-                and self.pytree_info.in_spec.num_children == 2
-                and self.pytree_info.in_spec.children_specs[0].type is tuple
-                and self.pytree_info.in_spec.children_specs[1].type is dict
-            )
-            fn_kwargs = "{}"
-            fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
-            if has_args_kwargs_tuple:
-                count_args = self.pytree_info.in_spec.children_specs[0].num_children
-                fn_args = self.pytree_info.orig_args[:count_args]
-                fn_kwargs = (
-                    "{"
-                    + ", ".join(
-                        f"'{k}':{v}"
-                        for k, v in zip(
-                            self.pytree_info.in_spec.children_specs[1].context,
-                            self.pytree_info.orig_args[count_args:],
-                        )
-                    )
-                    + "}"
-                )
-                fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
-
-            # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
-            # we need to split it to two lines:
-            # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
-            # one for code: `var1, var2, = function_call()`
-            without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
-            fn_definition += self._format_annotations(free_vars, expanded_def)
-            fn_definition += f"""
-    {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
+            fn_definition += self.gen_var_bindings(fn_args, free_vars, expanded_def)
         return fn_definition

     def generate_output(self, output_args, *, descs: Optional[Any] = None):

@@ -1014,6 +1018,52 @@ class _PyTreeCodeGen(CodeGen):
         return super().generate_output(output_args, descs=descs)


+class _ExportCodeGen(_PyTreeCodeGen):
+    def __init__(
+        self,
+        pytree_info: _PyTreeInfo,
+        in_shuffle_graph: "GraphModule",
+        out_shuffle_graph: "GraphModule",
+        tree_leaf_names: list[str],
+        root: Optional[torch.nn.Module],
+    ):
+        super().__init__(pytree_info)
+        self.in_shuffle_graph = in_shuffle_graph
+        self.out_shuffle_graph = out_shuffle_graph
+        self.tree_leaf_names = tree_leaf_names
+        self.root = root
+
+    def process_inputs(self, *inputs: Any) -> Any:
+        flat_args = super().process_inputs(*inputs)
+        if self.root is not None:
+            flat_args = (self.root, *flat_args)
+        self.flat_args = flat_args
+        return self.in_shuffle_graph(*flat_args)
+
+    def process_outputs(self, out: Any) -> Any:
+        flat_outs = self.out_shuffle_graph(*self.flat_args, *out)
+        del self.flat_args
+        ret = super().process_outputs(flat_outs)
+        return ret
+
+    def gen_fn_def(self, *args, **kwargs) -> str:
+        fn_def = super().gen_fn_def(*args, **kwargs)
+        return fn_def
+
+    def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
+        without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
+        fn_signature: str = f"{', '.join(fn_args)}"
+        if self.root is not None:
+            fn_signature = f"self, {fn_signature}"
+        return f"""
+    {", ".join(self.tree_leaf_names)}, = pytree.tree_leaves(({fn_signature},))
+    {", ".join(without_annotation)}, = self._in_shuffle_graph({", ".join(self.tree_leaf_names)})"""
+
+    def generate_output(self, output_args, *args, **kwargs) -> str:
+        output = f"self._out_shuffle_graph({', '.join(self.tree_leaf_names)}, {', '.join([str(a) for a in output_args])})"
+        return f"return pytree.tree_unflatten({output}, self._out_spec)"
+
+
 class _FindNodesLookupTable:
     """
     Side table for the graph for the purpose of doing fast queries