diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py index b04c9324cb9..5152c8956f5 100644 --- a/test/distributed/tensor/test_dtensor_export.py +++ b/test/distributed/tensor/test_dtensor_export.py @@ -442,13 +442,22 @@ class DTensorExportTest(TestCase): # Run model to verify it works output = model(*inputs) - with torch._dynamo.config.patch(install_free_tensors=True): + with torch._dynamo.config.patch( + install_free_tensors=(export_fn is _dynamo_graph_capture_for_export) + ): # TODO: switch to use the official graph_capture API once it is ready gm = export_fn(model)(*inputs) output_gm = gm(*inputs) self.assertEqual(output, output_gm) - def test_flex_attention_dtensor_export(self): + @parametrize( + "export_fn", + [ + graph_capture_and_aot_export_joint_with_descriptors_v2, + graph_capture_and_aot_export_joint_with_descriptors, + ], + ) + def test_flex_attention_dtensor_export(self, export_fn): device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) model = FlexAttentionModel(self.device_type) @@ -485,9 +494,7 @@ class DTensorExportTest(TestCase): flex_kwargs = {"block_mask": block_mask} - joint_gm = graph_capture_and_aot_export_joint_with_descriptors( - tp_model, inputs, flex_kwargs - ) + joint_gm = export_fn(tp_model, inputs, flex_kwargs) self.assertTrue( _count_op(joint_gm, torch.ops.higher_order.flex_attention), diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 2f4a370ac79..a1cc8856810 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -3,16 +3,19 @@ import copy import types import unittest +from dataclasses import dataclass from typing import Dict, List, Tuple import torch import torch._dynamo +from torch._dynamo.functional_export import dynamo_graph_capture_for_export from torch._dynamo.test_case import run_tests, TestCase from torch._functorch.aot_autograd import aot_export_module from torch.export import export from torch.export.experimental import _export_forward_backward, _sticky_export from torch.export.graph_signature import OutputKind from torch.testing import FileCheck +from torch.testing._internal.common_utils import TEST_CUDA @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported") @@ -403,8 +406,6 @@ def forward(self, x): self.assertEqual(res_export, res_eager) def test_dynamo_graph_capture(self): - from torch._dynamo.functional_export import dynamo_graph_capture_for_export - class Foo(torch.nn.Module): def forward(self, dct, lst, bleh): x = dct["a"] * lst[1][0] @@ -439,6 +440,151 @@ def forward(self, x): test_inputs = make_inputs() self.assertEqual(gm(*test_inputs), foo(*test_inputs)) + def test_dynamo_graph_capture_custom_pytree_type(self): + import torch.utils._pytree as pytree + + @dataclass + class Bar: + x: torch.Tensor + y: torch.Tensor + + class Foo(torch.nn.Module): + def forward(self, bar: Bar): + return bar.x + bar.y + + foo = Foo() + + def make_inputs(): + return (Bar(torch.randn(2, 3), torch.randn(2, 3)),) + + pytree.register_dataclass(Bar) + try: + trace_inputs = make_inputs() + gm = dynamo_graph_capture_for_export(foo)(*trace_inputs) + test_inputs = make_inputs() + self.assertExpectedInline( + gm._in_shuffle_graph.code.strip("\r\n "), + """\ +def forward(self, arg0_1, arg1_1, arg2_1): + return (arg1_1, arg2_1)""", + ) + self.assertExpectedInline( + gm.code.strip("\r\n "), + """\ +def forward(self, args_0): + _tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0,)) + L_bar_x , L_bar_y , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2) + l_bar_x = L_bar_x + l_bar_y = L_bar_y + add = l_bar_x + l_bar_y; l_bar_x = l_bar_y = None + return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, add), self._out_spec)""", + ) + self.assertEqual(gm(*test_inputs), foo(*test_inputs)) + finally: + pytree._deregister_pytree_node(Bar) + + def test_dynamo_graph_capture_closure(self): + from torch.export import Dim + + N = 3 + outer = torch.randn(10, 32) + + class MyModel(torch.nn.Module): + def forward(self, x): + z = x + outer + y = z[:-1, :] # [s0 - 1, 32] + stacked = torch.stack([y] * N, dim=0) # [N * (s0 - 1), 32] + reshaped = stacked.reshape(-1, N, 32) # [(s0 - 1), N, 32] + return reshaped + + inps = (torch.randn(10, 32),) + ep = dynamo_graph_capture_for_export(MyModel())(*inps) + self.assertExpectedInline( + ep._in_shuffle_graph.code.strip("\r\n "), + """\ +def forward(self, arg0_1, arg1_1): + _tensor_constant0 = self._tensor_constant0 + return (arg1_1, _tensor_constant0)""", + ) + self.assertExpectedInline( + ep.code.strip("\r\n "), + """\ +def forward(self, args_0): + _tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,)) + L_x_ , L_outer_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1) + l_x_ = L_x_ + l_outer_ = L_outer_ + z = l_x_ + l_outer_; l_x_ = l_outer_ = None + y = z[(slice(None, -1, None), slice(None, None, None))]; z = None + stacked = torch.stack([y, y, y], dim = 0); y = None + reshaped = stacked.reshape(-1, 3, 32); stacked = None + return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, reshaped), self._out_spec)""", + ) + self.assertEqual(ep(*inps), MyModel()(*inps)) + + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self): + class DummyOp(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scalar): + ctx.save_for_backward(x) + return x + scalar + + @staticmethod + def backward(ctx, grad_out): + return grad_out, None + + def mock_fw_compute(x): + with fx_traceback.annotate({"compute": 0}): + return DummyOp.apply(x, 10) + + def mock_bw_comm(x): + with fx_traceback.annotate({"comm": 0}): + return DummyOp.apply(x, 20) + + def mock_bw_compute(x): + return DummyOp.apply(x, 30) + + class Model(torch.nn.Module): + def forward(self, fw_in, bw_in): + fw_out = mock_fw_compute(fw_in) + # bw_in blocks bw_out + bw_in = mock_bw_comm(bw_in) + bw_out = mock_bw_compute(bw_in) + return fw_out, bw_out + + def input_fn(): + inputs = (torch.rand(2, 128, device="cuda", requires_grad=True),) + grad_ins = (torch.rand(2, 128, device="cuda"),) + return ( + *inputs, + *grad_ins, + ) + + with torch.device("meta"): + model = Model() + + import torch.fx.traceback as fx_traceback + + with fx_traceback.preserve_node_meta(): + gm = dynamo_graph_capture_for_export(model)(*input_fn()) + + """ + def forward(self, args_0, args_1): + _tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0, args_1,)) + L_fw_in_ , L_bw_in_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2) + l_fw_in_ = L_fw_in_ + l_bw_in_ = L_bw_in_ + fwd_body_0 = self.fwd_body_0 + bwd_body_0 = self.bwd_body_0 + fw_out = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_fw_in_, args_tensor_mask = [True, False], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_fw_in_ = None + bw_in = l_bw_in_ + 20; l_bw_in_ = None + bw_out = bw_in + 30; bw_in = None + return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, fw_out, bw_out), self._out_spec) + """ + test_inputs = input_fn() + self.assertEqual(gm(*test_inputs), model(*test_inputs)) + if __name__ == "__main__": run_tests() diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index e123e6d6d60..74a53c6d9c4 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -896,6 +896,7 @@ class DynamoOutput: output_graph.import_sources, output_graph.traced_code, self.bytecode, + self.tracer_output.closure, ) @@ -927,6 +928,7 @@ class GraphCaptureOutput: import_sources: dict[str, str] traced_code: list[CodeType] bytecode: CodeType + closure: Optional[tuple[Any, ...]] def build_guards( self, @@ -981,7 +983,7 @@ class CaptureOutput: return types.FunctionType( self.graph_capture_output.bytecode, f_globals, - closure=(), + closure=self.graph_capture_output.closure, ) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index a7edabe0daa..fd41236ef9b 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1211,7 +1211,9 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg] # Make dynamo graph to have same input/output spec as user code def argument_names( - f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any] + f_sig: inspect.Signature, + args: Union[list[Any], tuple[Any, ...]], + kwargs: dict[str, Any], ) -> list[str]: def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec: # Get a list of Parameter objects from the Signature object diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py index 28801b99b82..d7c4f6a3001 100644 --- a/torch/_dynamo/functional_export.py +++ b/torch/_dynamo/functional_export.py @@ -1,10 +1,9 @@ -import copy import inspect import logging import traceback -import types from collections import namedtuple from collections.abc import Callable +from dataclasses import dataclass from typing import Any, Optional, TYPE_CHECKING, Union import sympy @@ -19,17 +18,19 @@ from torch._dynamo.utils import dynamo_timed, get_metrics_context from torch._export.utils import _compiling_state_context from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint from torch.fx import Node +from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, DimDynamic, + ShapeEnv, StatelessSymbolicContext, ) -from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.fx.graph import _ExportCodeGen, _PyTreeCodeGen, _PyTreeInfo +from torch.utils._pytree import TreeSpec if TYPE_CHECKING: from torch._subclasses.fake_tensor import FakeTensorMode - from torch.utils._pytree import TreeSpec log = logging.getLogger(__name__) @@ -448,9 +449,20 @@ def _suggest_or_raise_constraint_violation( raise constraint_violation_error +@dataclass(frozen=True) +class PyTreeifyOutput: + graph_module: torch.fx.GraphModule + in_spec: TreeSpec + in_shuffle_graph: torch.fx.GraphModule + num_flat_args: int + out_spec: TreeSpec + out_shuffle_graph: torch.fx.GraphModule + root: Optional[torch.nn.Module] = None + + def pytreeify( out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any] -) -> Any: +) -> PyTreeifyOutput: """ Given a dynamo capture output, return a callable graph module that contain the following information: @@ -468,10 +480,13 @@ def pytreeify( backend_input = out.backend_input backend = out.backend_input.graph_module + root = None if isinstance(mod, torch.nn.Module): args = (mod,) + args + root = mod elif inspect.ismethod(mod): args = (mod.__self__,) + args + root = mod.__self__ flat_real_args, in_spec = pytree.tree_flatten((args, kwargs)) @@ -504,15 +519,21 @@ def pytreeify( backend_input.graph_module = backend raise RuntimeError - in_shuffle_graph = torch.fx.symbolic_trace(InShuffle()) + fake_mode = torch._dynamo.utils.detect_fake_mode(flat_real_args) + if fake_mode and fake_mode.shape_env is None: + fake_mode.shape_env = ShapeEnv() + in_shuffle_graph = make_fx( + InShuffle(), tracing_mode="symbolic", proxy_module_inputs=True + )(*flat_real_args) + + output_node = next(iter(reversed(backend_input.graph_module.graph.nodes))) class OutShuffle(torch.nn.Module): def __init__(self): super().__init__() self.num_inputs = len(flat_real_args) - self.num_outputs = len( - next(iter(reversed(backend_input.graph_module.graph.nodes))).args[0] - ) + + self.num_outputs = len(output_node.args[0]) self.out_spec: Optional[TreeSpec] = None def forward(self, *flat_proxy_args): @@ -535,49 +556,101 @@ def pytreeify( return ret out_shuffle = OutShuffle() - out_shuffle_graph = torch.fx.symbolic_trace(out_shuffle) + flat_out_shuffle_args = [ + *flat_real_args, + *pytree.tree_map_only( + torch.fx.Node, + lambda x: fake_mode.from_tensor(x.meta["example_value"]) + if fake_mode + else x.meta["example_value"], + output_node.args[0], + ), + ] + fake_mode = torch._dynamo.utils.detect_fake_mode(flat_out_shuffle_args) + if fake_mode and fake_mode.shape_env is None: + fake_mode.shape_env = ShapeEnv() + out_shuffle_graph = make_fx( + out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True + )(*flat_out_shuffle_args) - def pytree_call(*args, **kwargs): - import torch.export._unlift + assert out_shuffle.out_spec is not None + return PyTreeifyOutput( + backend_input.graph_module, + in_spec, + in_shuffle_graph, + len(flat_real_args), + out_shuffle.out_spec, + out_shuffle_graph, + root=root, # type: ignore[arg-type] + ) - flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs)) - if not torch.export._unlift.eq_spec(in_spec_runtime, in_spec): - raise RuntimeError( - f"Model input mismatch. Expected input spec: {in_spec}. Actual input spec: {in_spec_runtime}" - ) - flat_outs = backend_input.graph_module(*in_shuffle_graph(*flat_args)) - assert out_shuffle.out_spec is not None - return pytree.tree_unflatten( - out_shuffle_graph(*flat_args, *flat_outs), out_shuffle.out_spec - ) - if isinstance(mod, torch.nn.Module): - compiled_mod = copy.copy(mod) - compiled_mod.forward = types.MethodType(pytree_call, compiled_mod) - if not hasattr(compiled_mod, "meta"): - compiled_mod.meta = {} # type: ignore[attr-defined] - if isinstance(compiled_mod.meta, dict) and "fake_mode" not in compiled_mod.meta: - compiled_mod.meta["fake_mode"] = out.backend_input.fake_mode - return compiled_mod - elif inspect.ismethod(mod): - return types.MethodType(pytree_call, mod.__self__) - else: - return pytree_call +def normalize_graph_module(gm): + for node in gm.graph.nodes: + if node.op == "placeholder": + node.meta["val"] = node.meta["example_value"] def dynamo_graph_capture_for_export( mod: Callable[..., Any], + constraints: Optional[list[Constraint]] = None, ) -> Callable[..., Any]: def inner(*args: Any, **kwargs: Any) -> Any: + assert not torch._dynamo.config.install_free_tensors with ( get_metrics_context(), dynamo_timed("fullgraph_capture"), ): - out = fullgraph_capture(mod, args, kwargs) + out = fullgraph_capture( + mod, + args, + kwargs, + constraints=constraints, + ) # TODO filter out side effects. + pyt = pytreeify(out, mod, args, kwargs) - return pytreeify(out, mod, args, kwargs) + graph_module = pyt.graph_module + tree_leaf_names = [ + graph_module.graph._graph_namespace.create_name(f"_tree_leaf_{i}", None) + for i in range(pyt.num_flat_args) + ] + graph_module.graph._codegen = _ExportCodeGen( + _PyTreeInfo( + # TODO we should be able to use the names from dynamo graph directly. + argument_names(inspect.signature(mod), args, kwargs), + pyt.in_spec, + pyt.out_spec, + ), + pyt.in_shuffle_graph, + pyt.out_shuffle_graph, + tree_leaf_names, + pyt.root, + ) # type: ignore[attr-defined] + normalize_graph_module(graph_module) + if pyt.root is not None: + graph_module._parameters = pyt.root._parameters.copy() + graph_module._buffers = pyt.root._buffers.copy() + assert all(not hasattr(graph_module, m) for m in pyt.root._modules) + graph_module._modules.update(pyt.root._modules) + graph_module._non_persistent_buffers_set = ( + pyt.root._non_persistent_buffers_set.copy() + ) + graph_module._in_spec = pyt.in_spec + graph_module._out_spec = pyt.out_spec + assert not hasattr(graph_module, "_in_shuffle_graph") + assert not hasattr(graph_module, "_out_shuffle_graph") + graph_module._in_shuffle_graph = pyt.in_shuffle_graph + graph_module._out_shuffle_graph = pyt.out_shuffle_graph + delattr(graph_module, "_param_name_to_source") + graph_module.recompile() + graph_module.meta["module_call_specs"] = ( + out.graph_capture_output.output_graph.export_metadata.module_call_spec + ) + assert out.backend_input is not None + graph_module.meta["fake_mode"] = out.backend_input.fake_mode # type: ignore[attr-defined] + return graph_module return inner diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 961de577fbf..50638ccbba0 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -2747,12 +2747,14 @@ class DynamoTracerOutput: error_on_graph_break: bool is_tracing_resume_prologue: bool output_graph: Optional[OutputGraph] + closure: Optional[tuple[Any, ...]] def __init__( self, tracer: "InstructionTranslatorBase", error: Optional[Any] = None ) -> None: self.error_on_graph_break = tracer.error_on_graph_break self.is_tracing_resume_prologue = tracer.is_tracing_resume_prologue + self.closure = tracer.closure if error: self.output_graph = None else: diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 9a30b729db0..d5cff162c8e 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -4217,6 +4217,7 @@ class InstructionTranslatorBase( self.f_builtins: dict[str, Any] = f_builtins self.code_options: dict[str, Any] = code_options self.f_code: types.CodeType = f_code + self.closure = closure # Execution record for replaying errors if closure is not None and config.replay_record_enabled: diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 480490fd1a0..8e310833e71 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -97,7 +97,7 @@ from torch.fx.experimental.symbolic_shapes import ( GuardOnDataDependentSymNode, ShapeEnv, ) -from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.fx.graph import _PyTreeInfo from torch.utils._pytree import TreeSpec from torch.utils._sympy.value_ranges import ValueRangeError @@ -1537,12 +1537,10 @@ def _strict_export( orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined] - gm_torch_level.graph._codegen = _PyTreeCodeGen( - _PyTreeInfo( - orig_arg_names, - gm_torch_level._in_spec, - out_spec, - ) + gm_torch_level.graph._codegen.pytree_info = _PyTreeInfo( + orig_arg_names, + gm_torch_level._in_spec, + out_spec, ) gm_torch_level.recompile() diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 3d3dd1cb22c..dd5fab2b257 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1489,12 +1489,20 @@ def wrap_key( @functools.wraps(f) def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R: + nonlocal tensors + flat_proxies, _proxies_spec = pytree.tree_flatten(proxies) assert len(flat_proxies) == len(flat_tensors) with disable_proxy_modes_tracing() as m: assert isinstance(m, ProxyTorchDispatchMode) track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) + if getattr(tracer, "proxy_module_inputs", False): + tensors = [ # type: ignore[assignment, var-annotated] + p if isinstance(t, torch.nn.Module) else t + for t, p in zip(tensors, proxies) # type: ignore[arg-type] + ] + def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]: return get_proxy_slot(t, tracer, t, lambda x: x.proxy) # type: ignore[attr-defined] @@ -2208,6 +2216,7 @@ class _MakefxTracer: _error_on_data_dependent_ops: bool, record_stack_traces: bool = False, parent_tracer: Optional[_MakefxTracer] = None, + proxy_module_inputs: bool = False, ) -> None: # Configurations that are used to initialize the context managers and their states. # Should not modify them during tracing. @@ -2240,6 +2249,7 @@ class _MakefxTracer: ) self.record_stack_traces = record_stack_traces self.parent_tracer: Optional[_MakefxTracer] = parent_tracer + self.proxy_module_inputs = proxy_module_inputs def _checkpoint_modes(self) -> list[Any]: return [ @@ -2349,6 +2359,7 @@ class _MakefxTracer: self.python_dispatcher_mode = enable_python_dispatcher() self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer) + fx_tracer.proxy_module_inputs = self.proxy_module_inputs # type: ignore[union-attr] @contextmanager def _init_modes_from_parent( @@ -2551,6 +2562,7 @@ def make_fx( _allow_fake_constant: bool = False, _error_on_data_dependent_ops: bool = True, record_stack_traces: bool = False, + proxy_module_inputs: bool = False, ) -> Callable[..., GraphModule]: """ Given a function f, return a new function which when executed with valid @@ -2574,6 +2586,7 @@ def make_fx( _error_on_data_dependent_ops, record_stack_traces=record_stack_traces or config.trace.provenance_tracking_level == 1, + proxy_module_inputs=proxy_module_inputs, ) @functools.wraps(f) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index f2ecc3e917f..3435695c14d 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -930,6 +930,42 @@ class _PyTreeCodeGen(CodeGen): else: return "\n " + "".join(x + "; " for x in has_annotation) + "\n" + def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: + # when kwargs is present, in_spec is tuple(args, kwargs) + has_args_kwargs_tuple = ( + self.pytree_info.in_spec.type is tuple + and self.pytree_info.in_spec.num_children == 2 + and self.pytree_info.in_spec.children_specs[0].type is tuple + and self.pytree_info.in_spec.children_specs[1].type is dict + ) + fn_kwargs = "{}" + fn_signature = f"[{', '.join(fn_args)}], self._in_spec" + if has_args_kwargs_tuple: + count_args = self.pytree_info.in_spec.children_specs[0].num_children + fn_args = self.pytree_info.orig_args[:count_args] + fn_kwargs = ( + "{" + + ", ".join( + f"'{k}':{v}" + for k, v in zip( + self.pytree_info.in_spec.children_specs[1].context, + self.pytree_info.orig_args[count_args:], + ) + ) + + "}" + ) + fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec" + + # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid. + # we need to split it to two lines: + # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon) + # one for code: `var1, var2, = function_call()` + without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars] + bindings = self._format_annotations(free_vars, expanded_def) + bindings += f""" + {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" + return bindings + def gen_fn_def( self, free_vars, maybe_return_annotation, *, expanded_def: bool = False ): @@ -962,39 +998,7 @@ class _PyTreeCodeGen(CodeGen): ) if len(free_vars) > 0: # pytree has placeholders in it - # when kwargs is present, in_spec is tuple(args, kwargs) - has_args_kwargs_tuple = ( - self.pytree_info.in_spec.type is tuple - and self.pytree_info.in_spec.num_children == 2 - and self.pytree_info.in_spec.children_specs[0].type is tuple - and self.pytree_info.in_spec.children_specs[1].type is dict - ) - fn_kwargs = "{}" - fn_signature = f"[{', '.join(fn_args)}], self._in_spec" - if has_args_kwargs_tuple: - count_args = self.pytree_info.in_spec.children_specs[0].num_children - fn_args = self.pytree_info.orig_args[:count_args] - fn_kwargs = ( - "{" - + ", ".join( - f"'{k}':{v}" - for k, v in zip( - self.pytree_info.in_spec.children_specs[1].context, - self.pytree_info.orig_args[count_args:], - ) - ) - + "}" - ) - fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec" - - # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid. - # we need to split it to two lines: - # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon) - # one for code: `var1, var2, = function_call()` - without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars] - fn_definition += self._format_annotations(free_vars, expanded_def) - fn_definition += f""" - {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" + fn_definition += self.gen_var_bindings(fn_args, free_vars, expanded_def) return fn_definition def generate_output(self, output_args, *, descs: Optional[Any] = None): @@ -1014,6 +1018,52 @@ class _PyTreeCodeGen(CodeGen): return super().generate_output(output_args, descs=descs) +class _ExportCodeGen(_PyTreeCodeGen): + def __init__( + self, + pytree_info: _PyTreeInfo, + in_shuffle_graph: "GraphModule", + out_shuffle_graph: "GraphModule", + tree_leaf_names: list[str], + root: Optional[torch.nn.Module], + ): + super().__init__(pytree_info) + self.in_shuffle_graph = in_shuffle_graph + self.out_shuffle_graph = out_shuffle_graph + self.tree_leaf_names = tree_leaf_names + self.root = root + + def process_inputs(self, *inputs: Any) -> Any: + flat_args = super().process_inputs(*inputs) + if self.root is not None: + flat_args = (self.root, *flat_args) + self.flat_args = flat_args + return self.in_shuffle_graph(*flat_args) + + def process_outputs(self, out: Any) -> Any: + flat_outs = self.out_shuffle_graph(*self.flat_args, *out) + del self.flat_args + ret = super().process_outputs(flat_outs) + return ret + + def gen_fn_def(self, *args, **kwargs) -> str: + fn_def = super().gen_fn_def(*args, **kwargs) + return fn_def + + def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: + without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars] + fn_signature: str = f"{', '.join(fn_args)}" + if self.root is not None: + fn_signature = f"self, {fn_signature}" + return f""" + {", ".join(self.tree_leaf_names)}, = pytree.tree_leaves(({fn_signature},)) + {", ".join(without_annotation)}, = self._in_shuffle_graph({", ".join(self.tree_leaf_names)})""" + + def generate_output(self, output_args, *args, **kwargs) -> str: + output = f"self._out_shuffle_graph({', '.join(self.tree_leaf_names)}, {', '.join([str(a) for a in output_args])})" + return f"return pytree.tree_unflatten({output}, self._out_spec)" + + class _FindNodesLookupTable: """ Side table for the graph for the purpose of doing fast queries