[export] Update dynamo_graph_capture_for_export to return GraphModule. (#166091)
Make dynamo_graph_capture_for_export return a more compatible GraphModule object that is closer to the original behavior of dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166091
Approved by: https://github.com/tugsbayasgalan
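For reference, a minimal usage sketch based on the tests added in this PR (the Foo module and tensor shapes are illustrative only, not part of the change):

import torch
from torch._dynamo.functional_export import dynamo_graph_capture_for_export

class Foo(torch.nn.Module):
    def forward(self, x, y):
        return x + y

foo = Foo()
inputs = (torch.randn(2, 3), torch.randn(2, 3))

# The capture now returns a torch.fx.GraphModule whose generated forward
# flattens the pytree inputs, runs the captured dynamo graph, and unflattens
# the outputs, so it can be called like the original module.
gm = dynamo_graph_capture_for_export(foo)(*inputs)
print(gm.code)
assert torch.equal(gm(*inputs), foo(*inputs))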
This commit is contained in:
parent a77f5d9a00
commit f93ea7dab1
@@ -442,13 +442,22 @@ class DTensorExportTest(TestCase):
         # Run model to verify it works
         output = model(*inputs)
-        with torch._dynamo.config.patch(install_free_tensors=True):
+        with torch._dynamo.config.patch(
+            install_free_tensors=(export_fn is _dynamo_graph_capture_for_export)
+        ):
             # TODO: switch to use the official graph_capture API once it is ready
             gm = export_fn(model)(*inputs)
         output_gm = gm(*inputs)
         self.assertEqual(output, output_gm)
 
-    def test_flex_attention_dtensor_export(self):
+    @parametrize(
+        "export_fn",
+        [
+            graph_capture_and_aot_export_joint_with_descriptors_v2,
+            graph_capture_and_aot_export_joint_with_descriptors,
+        ],
+    )
+    def test_flex_attention_dtensor_export(self, export_fn):
         device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
         model = FlexAttentionModel(self.device_type)
@@ -485,9 +494,7 @@ class DTensorExportTest(TestCase):
         flex_kwargs = {"block_mask": block_mask}
 
-        joint_gm = graph_capture_and_aot_export_joint_with_descriptors(
-            tp_model, inputs, flex_kwargs
-        )
+        joint_gm = export_fn(tp_model, inputs, flex_kwargs)
 
         self.assertTrue(
             _count_op(joint_gm, torch.ops.higher_order.flex_attention),
@@ -3,16 +3,19 @@
 import copy
 import types
 import unittest
+from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
 import torch
 import torch._dynamo
+from torch._dynamo.functional_export import dynamo_graph_capture_for_export
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._functorch.aot_autograd import aot_export_module
 from torch.export import export
 from torch.export.experimental import _export_forward_backward, _sticky_export
 from torch.export.graph_signature import OutputKind
 from torch.testing import FileCheck
+from torch.testing._internal.common_utils import TEST_CUDA
 
 
 @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
@@ -403,8 +406,6 @@ def forward(self, x):
         self.assertEqual(res_export, res_eager)
 
     def test_dynamo_graph_capture(self):
-        from torch._dynamo.functional_export import dynamo_graph_capture_for_export
-
         class Foo(torch.nn.Module):
             def forward(self, dct, lst, bleh):
                 x = dct["a"] * lst[1][0]
@@ -439,6 +440,151 @@ def forward(self, x):
         test_inputs = make_inputs()
         self.assertEqual(gm(*test_inputs), foo(*test_inputs))
 
+    def test_dynamo_graph_capture_custom_pytree_type(self):
+        import torch.utils._pytree as pytree
+
+        @dataclass
+        class Bar:
+            x: torch.Tensor
+            y: torch.Tensor
+
+        class Foo(torch.nn.Module):
+            def forward(self, bar: Bar):
+                return bar.x + bar.y
+
+        foo = Foo()
+
+        def make_inputs():
+            return (Bar(torch.randn(2, 3), torch.randn(2, 3)),)
+
+        pytree.register_dataclass(Bar)
+        try:
+            trace_inputs = make_inputs()
+            gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
+            test_inputs = make_inputs()
+            self.assertExpectedInline(
+                gm._in_shuffle_graph.code.strip("\r\n "),
+                """\
+def forward(self, arg0_1, arg1_1, arg2_1):
+    return (arg1_1, arg2_1)""",
+            )
+            self.assertExpectedInline(
+                gm.code.strip("\r\n "),
+                """\
+def forward(self, args_0):
+    _tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0,))
+    L_bar_x , L_bar_y , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2)
+    l_bar_x = L_bar_x
+    l_bar_y = L_bar_y
+    add = l_bar_x + l_bar_y; l_bar_x = l_bar_y = None
+    return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, add), self._out_spec)""",
+            )
+            self.assertEqual(gm(*test_inputs), foo(*test_inputs))
+        finally:
+            pytree._deregister_pytree_node(Bar)
+
+    def test_dynamo_graph_capture_closure(self):
+        from torch.export import Dim
+
+        N = 3
+        outer = torch.randn(10, 32)
+
+        class MyModel(torch.nn.Module):
+            def forward(self, x):
+                z = x + outer
+                y = z[:-1, :]  # [s0 - 1, 32]
+                stacked = torch.stack([y] * N, dim=0)  # [N * (s0 - 1), 32]
+                reshaped = stacked.reshape(-1, N, 32)  # [(s0 - 1), N, 32]
+                return reshaped
+
+        inps = (torch.randn(10, 32),)
+        ep = dynamo_graph_capture_for_export(MyModel())(*inps)
+        self.assertExpectedInline(
+            ep._in_shuffle_graph.code.strip("\r\n "),
+            """\
+def forward(self, arg0_1, arg1_1):
+    _tensor_constant0 = self._tensor_constant0
+    return (arg1_1, _tensor_constant0)""",
+        )
+        self.assertExpectedInline(
+            ep.code.strip("\r\n "),
+            """\
+def forward(self, args_0):
+    _tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
+    L_x_ , L_outer_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
+    l_x_ = L_x_
+    l_outer_ = L_outer_
+    z = l_x_ + l_outer_; l_x_ = l_outer_ = None
+    y = z[(slice(None, -1, None), slice(None, None, None))]; z = None
+    stacked = torch.stack([y, y, y], dim = 0); y = None
+    reshaped = stacked.reshape(-1, 3, 32); stacked = None
+    return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, reshaped), self._out_spec)""",
+        )
+        self.assertEqual(ep(*inps), MyModel()(*inps))
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
+        class DummyOp(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, scalar):
+                ctx.save_for_backward(x)
+                return x + scalar
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                return grad_out, None
+
+        def mock_fw_compute(x):
+            with fx_traceback.annotate({"compute": 0}):
+                return DummyOp.apply(x, 10)
+
+        def mock_bw_comm(x):
+            with fx_traceback.annotate({"comm": 0}):
+                return DummyOp.apply(x, 20)
+
+        def mock_bw_compute(x):
+            return DummyOp.apply(x, 30)
+
+        class Model(torch.nn.Module):
+            def forward(self, fw_in, bw_in):
+                fw_out = mock_fw_compute(fw_in)
+                # bw_in blocks bw_out
+                bw_in = mock_bw_comm(bw_in)
+                bw_out = mock_bw_compute(bw_in)
+                return fw_out, bw_out
+
+        def input_fn():
+            inputs = (torch.rand(2, 128, device="cuda", requires_grad=True),)
+            grad_ins = (torch.rand(2, 128, device="cuda"),)
+            return (
+                *inputs,
+                *grad_ins,
+            )
+
+        with torch.device("meta"):
+            model = Model()
+
+        import torch.fx.traceback as fx_traceback
+
+        with fx_traceback.preserve_node_meta():
+            gm = dynamo_graph_capture_for_export(model)(*input_fn())
+
+        """
+        def forward(self, args_0, args_1):
+            _tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0, args_1,))
+            L_fw_in_ , L_bw_in_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2)
+            l_fw_in_ = L_fw_in_
+            l_bw_in_ = L_bw_in_
+            fwd_body_0 = self.fwd_body_0
+            bwd_body_0 = self.bwd_body_0
+            fw_out = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_fw_in_, args_tensor_mask = [True, False], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_fw_in_ = None
+            bw_in = l_bw_in_ + 20; l_bw_in_ = None
+            bw_out = bw_in + 30; bw_in = None
+            return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, fw_out, bw_out), self._out_spec)
+        """
+        test_inputs = input_fn()
+        self.assertEqual(gm(*test_inputs), model(*test_inputs))
+
 
 if __name__ == "__main__":
     run_tests()
@@ -896,6 +896,7 @@ class DynamoOutput:
             output_graph.import_sources,
             output_graph.traced_code,
             self.bytecode,
+            self.tracer_output.closure,
         )
 
 
@@ -927,6 +928,7 @@ class GraphCaptureOutput:
     import_sources: dict[str, str]
     traced_code: list[CodeType]
     bytecode: CodeType
+    closure: Optional[tuple[Any, ...]]
 
     def build_guards(
         self,
@@ -981,7 +983,7 @@ class CaptureOutput:
         return types.FunctionType(
             self.graph_capture_output.bytecode,
             f_globals,
-            closure=(),
+            closure=self.graph_capture_output.closure,
         )
 
 
@@ -1211,7 +1211,9 @@ class _NullDecorator(contextlib.nullcontext):  # type: ignore[type-arg]
 # Make dynamo graph to have same input/output spec as user code
 def argument_names(
-    f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any]
+    f_sig: inspect.Signature,
+    args: Union[list[Any], tuple[Any, ...]],
+    kwargs: dict[str, Any],
 ) -> list[str]:
     def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
         # Get a list of Parameter objects from the Signature object
@@ -1,10 +1,9 @@
-import copy
 import inspect
 import logging
 import traceback
-import types
 from collections import namedtuple
 from collections.abc import Callable
+from dataclasses import dataclass
 from typing import Any, Optional, TYPE_CHECKING, Union
 
 import sympy
@@ -19,17 +18,19 @@ from torch._dynamo.utils import dynamo_timed, get_metrics_context
 from torch._export.utils import _compiling_state_context
 from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
 from torch.fx import Node
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import (
     ConstraintViolationError,
     DimDynamic,
     ShapeEnv,
     StatelessSymbolicContext,
 )
-from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from torch.fx.graph import _ExportCodeGen, _PyTreeCodeGen, _PyTreeInfo
+from torch.utils._pytree import TreeSpec
 
 
 if TYPE_CHECKING:
     from torch._subclasses.fake_tensor import FakeTensorMode
     from torch.utils._pytree import TreeSpec
 
 
 log = logging.getLogger(__name__)
@@ -448,9 +449,20 @@ def _suggest_or_raise_constraint_violation(
         raise constraint_violation_error
 
 
+@dataclass(frozen=True)
+class PyTreeifyOutput:
+    graph_module: torch.fx.GraphModule
+    in_spec: TreeSpec
+    in_shuffle_graph: torch.fx.GraphModule
+    num_flat_args: int
+    out_spec: TreeSpec
+    out_shuffle_graph: torch.fx.GraphModule
+    root: Optional[torch.nn.Module] = None
+
+
 def pytreeify(
     out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
-) -> Any:
+) -> PyTreeifyOutput:
     """
     Given a dynamo capture output, return a callable graph module that
     contain the following information:
@@ -468,10 +480,13 @@ def pytreeify(
     backend_input = out.backend_input
     backend = out.backend_input.graph_module
 
+    root = None
     if isinstance(mod, torch.nn.Module):
         args = (mod,) + args
+        root = mod
     elif inspect.ismethod(mod):
         args = (mod.__self__,) + args
+        root = mod.__self__
 
     flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
 
@@ -504,15 +519,21 @@ def pytreeify(
             backend_input.graph_module = backend
             raise RuntimeError
 
-    in_shuffle_graph = torch.fx.symbolic_trace(InShuffle())
+    fake_mode = torch._dynamo.utils.detect_fake_mode(flat_real_args)
+    if fake_mode and fake_mode.shape_env is None:
+        fake_mode.shape_env = ShapeEnv()
+    in_shuffle_graph = make_fx(
+        InShuffle(), tracing_mode="symbolic", proxy_module_inputs=True
+    )(*flat_real_args)
+
+    output_node = next(iter(reversed(backend_input.graph_module.graph.nodes)))
 
     class OutShuffle(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.num_inputs = len(flat_real_args)
-            self.num_outputs = len(
-                next(iter(reversed(backend_input.graph_module.graph.nodes))).args[0]
-            )
+            self.num_outputs = len(output_node.args[0])
             self.out_spec: Optional[TreeSpec] = None
 
         def forward(self, *flat_proxy_args):
@@ -535,49 +556,101 @@ def pytreeify(
             return ret
 
     out_shuffle = OutShuffle()
-    out_shuffle_graph = torch.fx.symbolic_trace(out_shuffle)
+    flat_out_shuffle_args = [
+        *flat_real_args,
+        *pytree.tree_map_only(
+            torch.fx.Node,
+            lambda x: fake_mode.from_tensor(x.meta["example_value"])
+            if fake_mode
+            else x.meta["example_value"],
+            output_node.args[0],
+        ),
+    ]
+    fake_mode = torch._dynamo.utils.detect_fake_mode(flat_out_shuffle_args)
+    if fake_mode and fake_mode.shape_env is None:
+        fake_mode.shape_env = ShapeEnv()
+    out_shuffle_graph = make_fx(
+        out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True
+    )(*flat_out_shuffle_args)
 
-    def pytree_call(*args, **kwargs):
-        import torch.export._unlift
-
-        flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs))
-        if not torch.export._unlift.eq_spec(in_spec_runtime, in_spec):
-            raise RuntimeError(
-                f"Model input mismatch. Expected input spec: {in_spec}. Actual input spec: {in_spec_runtime}"
-            )
-        flat_outs = backend_input.graph_module(*in_shuffle_graph(*flat_args))
-        assert out_shuffle.out_spec is not None
-        return pytree.tree_unflatten(
-            out_shuffle_graph(*flat_args, *flat_outs), out_shuffle.out_spec
-        )
-
-    if isinstance(mod, torch.nn.Module):
-        compiled_mod = copy.copy(mod)
-        compiled_mod.forward = types.MethodType(pytree_call, compiled_mod)
-        if not hasattr(compiled_mod, "meta"):
-            compiled_mod.meta = {}  # type: ignore[attr-defined]
-        if isinstance(compiled_mod.meta, dict) and "fake_mode" not in compiled_mod.meta:
-            compiled_mod.meta["fake_mode"] = out.backend_input.fake_mode
-        return compiled_mod
-    elif inspect.ismethod(mod):
-        return types.MethodType(pytree_call, mod.__self__)
-    else:
-        return pytree_call
+    assert out_shuffle.out_spec is not None
+    return PyTreeifyOutput(
+        backend_input.graph_module,
+        in_spec,
+        in_shuffle_graph,
+        len(flat_real_args),
+        out_shuffle.out_spec,
+        out_shuffle_graph,
+        root=root,  # type: ignore[arg-type]
+    )
+
+
+def normalize_graph_module(gm):
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            node.meta["val"] = node.meta["example_value"]
 
 
 def dynamo_graph_capture_for_export(
     mod: Callable[..., Any],
+    constraints: Optional[list[Constraint]] = None,
 ) -> Callable[..., Any]:
     def inner(*args: Any, **kwargs: Any) -> Any:
         assert not torch._dynamo.config.install_free_tensors
         with (
             get_metrics_context(),
             dynamo_timed("fullgraph_capture"),
         ):
-            out = fullgraph_capture(mod, args, kwargs)
+            out = fullgraph_capture(
+                mod,
+                args,
+                kwargs,
+                constraints=constraints,
+            )
 
         # TODO filter out side effects.
-        return pytreeify(out, mod, args, kwargs)
+        pyt = pytreeify(out, mod, args, kwargs)
+
+        graph_module = pyt.graph_module
+        tree_leaf_names = [
+            graph_module.graph._graph_namespace.create_name(f"_tree_leaf_{i}", None)
+            for i in range(pyt.num_flat_args)
+        ]
+        graph_module.graph._codegen = _ExportCodeGen(
+            _PyTreeInfo(
+                # TODO we should be able to use the names from dynamo graph directly.
+                argument_names(inspect.signature(mod), args, kwargs),
+                pyt.in_spec,
+                pyt.out_spec,
+            ),
+            pyt.in_shuffle_graph,
+            pyt.out_shuffle_graph,
+            tree_leaf_names,
+            pyt.root,
+        )  # type: ignore[attr-defined]
+        normalize_graph_module(graph_module)
+        if pyt.root is not None:
+            graph_module._parameters = pyt.root._parameters.copy()
+            graph_module._buffers = pyt.root._buffers.copy()
+            assert all(not hasattr(graph_module, m) for m in pyt.root._modules)
+            graph_module._modules.update(pyt.root._modules)
+            graph_module._non_persistent_buffers_set = (
+                pyt.root._non_persistent_buffers_set.copy()
+            )
+        graph_module._in_spec = pyt.in_spec
+        graph_module._out_spec = pyt.out_spec
+        assert not hasattr(graph_module, "_in_shuffle_graph")
+        assert not hasattr(graph_module, "_out_shuffle_graph")
+        graph_module._in_shuffle_graph = pyt.in_shuffle_graph
+        graph_module._out_shuffle_graph = pyt.out_shuffle_graph
+        delattr(graph_module, "_param_name_to_source")
+        graph_module.recompile()
+        graph_module.meta["module_call_specs"] = (
+            out.graph_capture_output.output_graph.export_metadata.module_call_spec
+        )
+        assert out.backend_input is not None
+        graph_module.meta["fake_mode"] = out.backend_input.fake_mode  # type: ignore[attr-defined]
+        return graph_module
 
     return inner
@@ -2747,12 +2747,14 @@ class DynamoTracerOutput:
     error_on_graph_break: bool
    is_tracing_resume_prologue: bool
     output_graph: Optional[OutputGraph]
+    closure: Optional[tuple[Any, ...]]
 
     def __init__(
         self, tracer: "InstructionTranslatorBase", error: Optional[Any] = None
     ) -> None:
         self.error_on_graph_break = tracer.error_on_graph_break
         self.is_tracing_resume_prologue = tracer.is_tracing_resume_prologue
+        self.closure = tracer.closure
         if error:
             self.output_graph = None
         else:
@@ -4217,6 +4217,7 @@ class InstructionTranslatorBase(
         self.f_builtins: dict[str, Any] = f_builtins
         self.code_options: dict[str, Any] = code_options
         self.f_code: types.CodeType = f_code
+        self.closure = closure
 
         # Execution record for replaying errors
         if closure is not None and config.replay_record_enabled:
@@ -97,7 +97,7 @@ from torch.fx.experimental.symbolic_shapes import (
     GuardOnDataDependentSymNode,
     ShapeEnv,
 )
-from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from torch.fx.graph import _PyTreeInfo
 from torch.utils._pytree import TreeSpec
 from torch.utils._sympy.value_ranges import ValueRangeError
@@ -1537,12 +1537,10 @@ def _strict_export(
 
     orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]
 
-    gm_torch_level.graph._codegen = _PyTreeCodeGen(
-        _PyTreeInfo(
-            orig_arg_names,
-            gm_torch_level._in_spec,
-            out_spec,
-        )
+    gm_torch_level.graph._codegen.pytree_info = _PyTreeInfo(
+        orig_arg_names,
+        gm_torch_level._in_spec,
+        out_spec,
     )
     gm_torch_level.recompile()
 
@@ -1489,12 +1489,20 @@ def wrap_key(
 
     @functools.wraps(f)
     def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R:
+        nonlocal tensors
+
         flat_proxies, _proxies_spec = pytree.tree_flatten(proxies)
         assert len(flat_proxies) == len(flat_tensors)
         with disable_proxy_modes_tracing() as m:
             assert isinstance(m, ProxyTorchDispatchMode)
             track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
 
+        if getattr(tracer, "proxy_module_inputs", False):
+            tensors = [  # type: ignore[assignment, var-annotated]
+                p if isinstance(t, torch.nn.Module) else t
+                for t, p in zip(tensors, proxies)  # type: ignore[arg-type]
+            ]
+
         def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
             return get_proxy_slot(t, tracer, t, lambda x: x.proxy)  # type: ignore[attr-defined]
@@ -2208,6 +2216,7 @@ class _MakefxTracer:
         _error_on_data_dependent_ops: bool,
         record_stack_traces: bool = False,
         parent_tracer: Optional[_MakefxTracer] = None,
+        proxy_module_inputs: bool = False,
     ) -> None:
         # Configurations that are used to initialize the context managers and their states.
         # Should not modify them during tracing.
@@ -2240,6 +2249,7 @@ class _MakefxTracer:
         )
         self.record_stack_traces = record_stack_traces
         self.parent_tracer: Optional[_MakefxTracer] = parent_tracer
+        self.proxy_module_inputs = proxy_module_inputs
 
     def _checkpoint_modes(self) -> list[Any]:
         return [
@@ -2349,6 +2359,7 @@ class _MakefxTracer:
             self.python_dispatcher_mode = enable_python_dispatcher()
 
         self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer)
+        fx_tracer.proxy_module_inputs = self.proxy_module_inputs  # type: ignore[union-attr]
 
     @contextmanager
     def _init_modes_from_parent(
@@ -2551,6 +2562,7 @@ def make_fx(
     _allow_fake_constant: bool = False,
     _error_on_data_dependent_ops: bool = True,
     record_stack_traces: bool = False,
+    proxy_module_inputs: bool = False,
 ) -> Callable[..., GraphModule]:
     """
     Given a function f, return a new function which when executed with valid
@@ -2574,6 +2586,7 @@ def make_fx(
         _error_on_data_dependent_ops,
         record_stack_traces=record_stack_traces
         or config.trace.provenance_tracking_level == 1,
+        proxy_module_inputs=proxy_module_inputs,
     )
 
     @functools.wraps(f)
@@ -930,6 +930,42 @@ class _PyTreeCodeGen(CodeGen):
         else:
             return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
 
+    def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
+        # when kwargs is present, in_spec is tuple(args, kwargs)
+        has_args_kwargs_tuple = (
+            self.pytree_info.in_spec.type is tuple
+            and self.pytree_info.in_spec.num_children == 2
+            and self.pytree_info.in_spec.children_specs[0].type is tuple
+            and self.pytree_info.in_spec.children_specs[1].type is dict
+        )
+        fn_kwargs = "{}"
+        fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
+        if has_args_kwargs_tuple:
+            count_args = self.pytree_info.in_spec.children_specs[0].num_children
+            fn_args = self.pytree_info.orig_args[:count_args]
+            fn_kwargs = (
+                "{"
+                + ", ".join(
+                    f"'{k}':{v}"
+                    for k, v in zip(
+                        self.pytree_info.in_spec.children_specs[1].context,
+                        self.pytree_info.orig_args[count_args:],
+                    )
+                )
+                + "}"
+            )
+            fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
+
+        # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
+        # we need to split it to two lines:
+        # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
+        # one for code: `var1, var2, = function_call()`
+        without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
+        bindings = self._format_annotations(free_vars, expanded_def)
+        bindings += f"""
+    {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
+        return bindings
+
     def gen_fn_def(
         self, free_vars, maybe_return_annotation, *, expanded_def: bool = False
     ):
@@ -962,39 +998,7 @@ class _PyTreeCodeGen(CodeGen):
         )
 
         if len(free_vars) > 0:  # pytree has placeholders in it
-            # when kwargs is present, in_spec is tuple(args, kwargs)
-            has_args_kwargs_tuple = (
-                self.pytree_info.in_spec.type is tuple
-                and self.pytree_info.in_spec.num_children == 2
-                and self.pytree_info.in_spec.children_specs[0].type is tuple
-                and self.pytree_info.in_spec.children_specs[1].type is dict
-            )
-            fn_kwargs = "{}"
-            fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
-            if has_args_kwargs_tuple:
-                count_args = self.pytree_info.in_spec.children_specs[0].num_children
-                fn_args = self.pytree_info.orig_args[:count_args]
-                fn_kwargs = (
-                    "{"
-                    + ", ".join(
-                        f"'{k}':{v}"
-                        for k, v in zip(
-                            self.pytree_info.in_spec.children_specs[1].context,
-                            self.pytree_info.orig_args[count_args:],
-                        )
-                    )
-                    + "}"
-                )
-                fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
-
-            # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
-            # we need to split it to two lines:
-            # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
-            # one for code: `var1, var2, = function_call()`
-            without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
-            fn_definition += self._format_annotations(free_vars, expanded_def)
-            fn_definition += f"""
-    {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
+            fn_definition += self.gen_var_bindings(fn_args, free_vars, expanded_def)
         return fn_definition
 
     def generate_output(self, output_args, *, descs: Optional[Any] = None):
@@ -1014,6 +1018,52 @@ class _PyTreeCodeGen(CodeGen):
         return super().generate_output(output_args, descs=descs)
 
 
+class _ExportCodeGen(_PyTreeCodeGen):
+    def __init__(
+        self,
+        pytree_info: _PyTreeInfo,
+        in_shuffle_graph: "GraphModule",
+        out_shuffle_graph: "GraphModule",
+        tree_leaf_names: list[str],
+        root: Optional[torch.nn.Module],
+    ):
+        super().__init__(pytree_info)
+        self.in_shuffle_graph = in_shuffle_graph
+        self.out_shuffle_graph = out_shuffle_graph
+        self.tree_leaf_names = tree_leaf_names
+        self.root = root
+
+    def process_inputs(self, *inputs: Any) -> Any:
+        flat_args = super().process_inputs(*inputs)
+        if self.root is not None:
+            flat_args = (self.root, *flat_args)
+        self.flat_args = flat_args
+        return self.in_shuffle_graph(*flat_args)
+
+    def process_outputs(self, out: Any) -> Any:
+        flat_outs = self.out_shuffle_graph(*self.flat_args, *out)
+        del self.flat_args
+        ret = super().process_outputs(flat_outs)
+        return ret
+
+    def gen_fn_def(self, *args, **kwargs) -> str:
+        fn_def = super().gen_fn_def(*args, **kwargs)
+        return fn_def
+
+    def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
+        without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
+        fn_signature: str = f"{', '.join(fn_args)}"
+        if self.root is not None:
+            fn_signature = f"self, {fn_signature}"
+        return f"""
+    {", ".join(self.tree_leaf_names)}, = pytree.tree_leaves(({fn_signature},))
+    {", ".join(without_annotation)}, = self._in_shuffle_graph({", ".join(self.tree_leaf_names)})"""
+
+    def generate_output(self, output_args, *args, **kwargs) -> str:
+        output = f"self._out_shuffle_graph({', '.join(self.tree_leaf_names)}, {', '.join([str(a) for a in output_args])})"
+        return f"return pytree.tree_unflatten({output}, self._out_spec)"
+
+
 class _FindNodesLookupTable:
     """
     Side table for the graph for the purpose of doing fast queries