[export] Update dynamo_graph_capture_for_export to return GraphModule. (#166091)

Make dynamo_graph_capture_for_export return a more compatible GraphModule object, which is closer to the original behavior of dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166091
Approved by: https://github.com/tugsbayasgalan
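As a rough usage sketch (not part of the commit, and mirroring the new tests added below), the returned object can now be called like the original module; the Foo module and input shapes here are made up for illustration:

import torch
from torch._dynamo.functional_export import dynamo_graph_capture_for_export

class Foo(torch.nn.Module):
    def forward(self, x):
        return x + 1

foo = Foo()
inputs = (torch.randn(2, 3),)

# The capture returns a GraphModule-like object whose forward accepts the
# original (pytree-structured) inputs and unflattens outputs back to the
# user's output spec, so it behaves like the eager module.
gm = dynamo_graph_capture_for_export(foo)(*inputs)
assert torch.equal(gm(*inputs), foo(*inputs))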
This commit is contained in:
parent a77f5d9a00
commit f93ea7dab1
@@ -442,13 +442,22 @@ class DTensorExportTest(TestCase):
 
         # Run model to verify it works
         output = model(*inputs)
-        with torch._dynamo.config.patch(install_free_tensors=True):
+        with torch._dynamo.config.patch(
+            install_free_tensors=(export_fn is _dynamo_graph_capture_for_export)
+        ):
             # TODO: switch to use the official graph_capture API once it is ready
             gm = export_fn(model)(*inputs)
         output_gm = gm(*inputs)
         self.assertEqual(output, output_gm)
 
-    def test_flex_attention_dtensor_export(self):
+    @parametrize(
+        "export_fn",
+        [
+            graph_capture_and_aot_export_joint_with_descriptors_v2,
+            graph_capture_and_aot_export_joint_with_descriptors,
+        ],
+    )
+    def test_flex_attention_dtensor_export(self, export_fn):
         device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
         model = FlexAttentionModel(self.device_type)
 
@@ -485,9 +494,7 @@ class DTensorExportTest(TestCase):
 
         flex_kwargs = {"block_mask": block_mask}
 
-        joint_gm = graph_capture_and_aot_export_joint_with_descriptors(
-            tp_model, inputs, flex_kwargs
-        )
+        joint_gm = export_fn(tp_model, inputs, flex_kwargs)
 
         self.assertTrue(
             _count_op(joint_gm, torch.ops.higher_order.flex_attention),
@@ -3,16 +3,19 @@
 import copy
 import types
 import unittest
+from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
 import torch
 import torch._dynamo
+from torch._dynamo.functional_export import dynamo_graph_capture_for_export
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._functorch.aot_autograd import aot_export_module
 from torch.export import export
 from torch.export.experimental import _export_forward_backward, _sticky_export
 from torch.export.graph_signature import OutputKind
 from torch.testing import FileCheck
+from torch.testing._internal.common_utils import TEST_CUDA
 
 
 @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
@@ -403,8 +406,6 @@ def forward(self, x):
         self.assertEqual(res_export, res_eager)
 
     def test_dynamo_graph_capture(self):
-        from torch._dynamo.functional_export import dynamo_graph_capture_for_export
-
         class Foo(torch.nn.Module):
             def forward(self, dct, lst, bleh):
                 x = dct["a"] * lst[1][0]
@@ -439,6 +440,151 @@ def forward(self, x):
         test_inputs = make_inputs()
         self.assertEqual(gm(*test_inputs), foo(*test_inputs))
 
+    def test_dynamo_graph_capture_custom_pytree_type(self):
+        import torch.utils._pytree as pytree
+
+        @dataclass
+        class Bar:
+            x: torch.Tensor
+            y: torch.Tensor
+
+        class Foo(torch.nn.Module):
+            def forward(self, bar: Bar):
+                return bar.x + bar.y
+
+        foo = Foo()
+
+        def make_inputs():
+            return (Bar(torch.randn(2, 3), torch.randn(2, 3)),)
+
+        pytree.register_dataclass(Bar)
+        try:
+            trace_inputs = make_inputs()
+            gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
+            test_inputs = make_inputs()
+            self.assertExpectedInline(
+                gm._in_shuffle_graph.code.strip("\r\n "),
+                """\
+def forward(self, arg0_1, arg1_1, arg2_1):
+    return (arg1_1, arg2_1)""",
+            )
+            self.assertExpectedInline(
+                gm.code.strip("\r\n "),
+                """\
+def forward(self, args_0):
+    _tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0,))
+    L_bar_x , L_bar_y , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2)
+    l_bar_x = L_bar_x
+    l_bar_y = L_bar_y
+    add = l_bar_x + l_bar_y; l_bar_x = l_bar_y = None
+    return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, add), self._out_spec)""",
+            )
+            self.assertEqual(gm(*test_inputs), foo(*test_inputs))
+        finally:
+            pytree._deregister_pytree_node(Bar)
+
+    def test_dynamo_graph_capture_closure(self):
+        from torch.export import Dim
+
+        N = 3
+        outer = torch.randn(10, 32)
+
+        class MyModel(torch.nn.Module):
+            def forward(self, x):
+                z = x + outer
+                y = z[:-1, :]  # [s0 - 1, 32]
+                stacked = torch.stack([y] * N, dim=0)  # [N * (s0 - 1), 32]
+                reshaped = stacked.reshape(-1, N, 32)  # [(s0 - 1), N, 32]
+                return reshaped
+
+        inps = (torch.randn(10, 32),)
+        ep = dynamo_graph_capture_for_export(MyModel())(*inps)
+        self.assertExpectedInline(
+            ep._in_shuffle_graph.code.strip("\r\n "),
+            """\
+def forward(self, arg0_1, arg1_1):
+    _tensor_constant0 = self._tensor_constant0
+    return (arg1_1, _tensor_constant0)""",
+        )
+        self.assertExpectedInline(
+            ep.code.strip("\r\n "),
+            """\
+def forward(self, args_0):
+    _tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
+    L_x_ , L_outer_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
+    l_x_ = L_x_
+    l_outer_ = L_outer_
+    z = l_x_ + l_outer_; l_x_ = l_outer_ = None
+    y = z[(slice(None, -1, None), slice(None, None, None))]; z = None
+    stacked = torch.stack([y, y, y], dim = 0); y = None
+    reshaped = stacked.reshape(-1, 3, 32); stacked = None
+    return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, reshaped), self._out_spec)""",
+        )
+        self.assertEqual(ep(*inps), MyModel()(*inps))
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
+        class DummyOp(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, scalar):
+                ctx.save_for_backward(x)
+                return x + scalar
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                return grad_out, None
+
+        def mock_fw_compute(x):
+            with fx_traceback.annotate({"compute": 0}):
+                return DummyOp.apply(x, 10)
+
+        def mock_bw_comm(x):
+            with fx_traceback.annotate({"comm": 0}):
+                return DummyOp.apply(x, 20)
+
+        def mock_bw_compute(x):
+            return DummyOp.apply(x, 30)
+
+        class Model(torch.nn.Module):
+            def forward(self, fw_in, bw_in):
+                fw_out = mock_fw_compute(fw_in)
+                # bw_in blocks bw_out
+                bw_in = mock_bw_comm(bw_in)
+                bw_out = mock_bw_compute(bw_in)
+                return fw_out, bw_out
+
+        def input_fn():
+            inputs = (torch.rand(2, 128, device="cuda", requires_grad=True),)
+            grad_ins = (torch.rand(2, 128, device="cuda"),)
+            return (
+                *inputs,
+                *grad_ins,
+            )
+
+        with torch.device("meta"):
+            model = Model()
+
+        import torch.fx.traceback as fx_traceback
+
+        with fx_traceback.preserve_node_meta():
+            gm = dynamo_graph_capture_for_export(model)(*input_fn())
+
+        """
+        def forward(self, args_0, args_1):
+            _tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0, args_1,))
+            L_fw_in_ , L_bw_in_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2)
+            l_fw_in_ = L_fw_in_
+            l_bw_in_ = L_bw_in_
+            fwd_body_0 = self.fwd_body_0
+            bwd_body_0 = self.bwd_body_0
+            fw_out = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_fw_in_, args_tensor_mask = [True, False], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_fw_in_ = None
+            bw_in = l_bw_in_ + 20; l_bw_in_ = None
+            bw_out = bw_in + 30; bw_in = None
+            return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, fw_out, bw_out), self._out_spec)
+        """
+        test_inputs = input_fn()
+        self.assertEqual(gm(*test_inputs), model(*test_inputs))
+
 
 if __name__ == "__main__":
     run_tests()
@@ -896,6 +896,7 @@ class DynamoOutput:
             output_graph.import_sources,
             output_graph.traced_code,
             self.bytecode,
+            self.tracer_output.closure,
         )
 
 
@@ -927,6 +928,7 @@ class GraphCaptureOutput:
     import_sources: dict[str, str]
     traced_code: list[CodeType]
     bytecode: CodeType
+    closure: Optional[tuple[Any, ...]]
 
     def build_guards(
         self,
@@ -981,7 +983,7 @@ class CaptureOutput:
         return types.FunctionType(
             self.graph_capture_output.bytecode,
             f_globals,
-            closure=(),
+            closure=self.graph_capture_output.closure,
         )
 
 
@@ -1211,7 +1211,9 @@ class _NullDecorator(contextlib.nullcontext):  # type: ignore[type-arg]
 
 # Make dynamo graph to have same input/output spec as user code
 def argument_names(
-    f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any]
+    f_sig: inspect.Signature,
+    args: Union[list[Any], tuple[Any, ...]],
+    kwargs: dict[str, Any],
 ) -> list[str]:
     def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
         # Get a list of Parameter objects from the Signature object
@@ -1,10 +1,9 @@
-import copy
 import inspect
 import logging
 import traceback
-import types
 from collections import namedtuple
 from collections.abc import Callable
+from dataclasses import dataclass
 from typing import Any, Optional, TYPE_CHECKING, Union
 
 import sympy
@@ -19,17 +18,19 @@ from torch._dynamo.utils import dynamo_timed, get_metrics_context
 from torch._export.utils import _compiling_state_context
 from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
 from torch.fx import Node
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import (
     ConstraintViolationError,
     DimDynamic,
+    ShapeEnv,
     StatelessSymbolicContext,
 )
-from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from torch.fx.graph import _ExportCodeGen, _PyTreeCodeGen, _PyTreeInfo
+from torch.utils._pytree import TreeSpec
 
 
 if TYPE_CHECKING:
     from torch._subclasses.fake_tensor import FakeTensorMode
-    from torch.utils._pytree import TreeSpec
 
 
 log = logging.getLogger(__name__)
@@ -448,9 +449,20 @@ def _suggest_or_raise_constraint_violation(
         raise constraint_violation_error
 
 
+@dataclass(frozen=True)
+class PyTreeifyOutput:
+    graph_module: torch.fx.GraphModule
+    in_spec: TreeSpec
+    in_shuffle_graph: torch.fx.GraphModule
+    num_flat_args: int
+    out_spec: TreeSpec
+    out_shuffle_graph: torch.fx.GraphModule
+    root: Optional[torch.nn.Module] = None
+
+
 def pytreeify(
     out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
-) -> Any:
+) -> PyTreeifyOutput:
     """
     Given a dynamo capture output, return a callable graph module that
     contain the following information:
@@ -468,10 +480,13 @@ def pytreeify(
     backend_input = out.backend_input
    backend = out.backend_input.graph_module
 
+    root = None
     if isinstance(mod, torch.nn.Module):
         args = (mod,) + args
+        root = mod
     elif inspect.ismethod(mod):
         args = (mod.__self__,) + args
+        root = mod.__self__
 
     flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
 
@@ -504,15 +519,21 @@ def pytreeify(
             backend_input.graph_module = backend
             raise RuntimeError
 
-    in_shuffle_graph = torch.fx.symbolic_trace(InShuffle())
+    fake_mode = torch._dynamo.utils.detect_fake_mode(flat_real_args)
+    if fake_mode and fake_mode.shape_env is None:
+        fake_mode.shape_env = ShapeEnv()
+    in_shuffle_graph = make_fx(
+        InShuffle(), tracing_mode="symbolic", proxy_module_inputs=True
+    )(*flat_real_args)
 
+    output_node = next(iter(reversed(backend_input.graph_module.graph.nodes)))
+
     class OutShuffle(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.num_inputs = len(flat_real_args)
-            self.num_outputs = len(
-                next(iter(reversed(backend_input.graph_module.graph.nodes))).args[0]
-            )
+            self.num_outputs = len(output_node.args[0])
             self.out_spec: Optional[TreeSpec] = None
 
         def forward(self, *flat_proxy_args):
@@ -535,49 +556,101 @@ def pytreeify(
             return ret
 
     out_shuffle = OutShuffle()
-    out_shuffle_graph = torch.fx.symbolic_trace(out_shuffle)
+    flat_out_shuffle_args = [
+        *flat_real_args,
+        *pytree.tree_map_only(
+            torch.fx.Node,
+            lambda x: fake_mode.from_tensor(x.meta["example_value"])
+            if fake_mode
+            else x.meta["example_value"],
+            output_node.args[0],
+        ),
+    ]
+    fake_mode = torch._dynamo.utils.detect_fake_mode(flat_out_shuffle_args)
+    if fake_mode and fake_mode.shape_env is None:
+        fake_mode.shape_env = ShapeEnv()
+    out_shuffle_graph = make_fx(
+        out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True
+    )(*flat_out_shuffle_args)
 
-    def pytree_call(*args, **kwargs):
-        import torch.export._unlift
-
-        flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs))
-        if not torch.export._unlift.eq_spec(in_spec_runtime, in_spec):
-            raise RuntimeError(
-                f"Model input mismatch. Expected input spec: {in_spec}. Actual input spec: {in_spec_runtime}"
-            )
-        flat_outs = backend_input.graph_module(*in_shuffle_graph(*flat_args))
-        assert out_shuffle.out_spec is not None
-        return pytree.tree_unflatten(
-            out_shuffle_graph(*flat_args, *flat_outs), out_shuffle.out_spec
-        )
-
-    if isinstance(mod, torch.nn.Module):
-        compiled_mod = copy.copy(mod)
-        compiled_mod.forward = types.MethodType(pytree_call, compiled_mod)
-        if not hasattr(compiled_mod, "meta"):
-            compiled_mod.meta = {}  # type: ignore[attr-defined]
-        if isinstance(compiled_mod.meta, dict) and "fake_mode" not in compiled_mod.meta:
-            compiled_mod.meta["fake_mode"] = out.backend_input.fake_mode
-        return compiled_mod
-    elif inspect.ismethod(mod):
-        return types.MethodType(pytree_call, mod.__self__)
-    else:
-        return pytree_call
+    assert out_shuffle.out_spec is not None
+    return PyTreeifyOutput(
+        backend_input.graph_module,
+        in_spec,
+        in_shuffle_graph,
+        len(flat_real_args),
+        out_shuffle.out_spec,
+        out_shuffle_graph,
+        root=root,  # type: ignore[arg-type]
+    )
+
+
+def normalize_graph_module(gm):
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            node.meta["val"] = node.meta["example_value"]
 
 
 def dynamo_graph_capture_for_export(
     mod: Callable[..., Any],
+    constraints: Optional[list[Constraint]] = None,
 ) -> Callable[..., Any]:
     def inner(*args: Any, **kwargs: Any) -> Any:
+        assert not torch._dynamo.config.install_free_tensors
         with (
             get_metrics_context(),
             dynamo_timed("fullgraph_capture"),
         ):
-            out = fullgraph_capture(mod, args, kwargs)
+            out = fullgraph_capture(
+                mod,
+                args,
+                kwargs,
+                constraints=constraints,
+            )
 
         # TODO filter out side effects.
-        return pytreeify(out, mod, args, kwargs)
+        pyt = pytreeify(out, mod, args, kwargs)
+
+        graph_module = pyt.graph_module
+        tree_leaf_names = [
+            graph_module.graph._graph_namespace.create_name(f"_tree_leaf_{i}", None)
+            for i in range(pyt.num_flat_args)
+        ]
+        graph_module.graph._codegen = _ExportCodeGen(
+            _PyTreeInfo(
+                # TODO we should be able to use the names from dynamo graph directly.
+                argument_names(inspect.signature(mod), args, kwargs),
+                pyt.in_spec,
+                pyt.out_spec,
+            ),
+            pyt.in_shuffle_graph,
+            pyt.out_shuffle_graph,
+            tree_leaf_names,
+            pyt.root,
+        )  # type: ignore[attr-defined]
+        normalize_graph_module(graph_module)
+        if pyt.root is not None:
+            graph_module._parameters = pyt.root._parameters.copy()
+            graph_module._buffers = pyt.root._buffers.copy()
+            assert all(not hasattr(graph_module, m) for m in pyt.root._modules)
+            graph_module._modules.update(pyt.root._modules)
+            graph_module._non_persistent_buffers_set = (
+                pyt.root._non_persistent_buffers_set.copy()
+            )
+        graph_module._in_spec = pyt.in_spec
+        graph_module._out_spec = pyt.out_spec
+        assert not hasattr(graph_module, "_in_shuffle_graph")
+        assert not hasattr(graph_module, "_out_shuffle_graph")
+        graph_module._in_shuffle_graph = pyt.in_shuffle_graph
+        graph_module._out_shuffle_graph = pyt.out_shuffle_graph
+        delattr(graph_module, "_param_name_to_source")
+        graph_module.recompile()
+        graph_module.meta["module_call_specs"] = (
+            out.graph_capture_output.output_graph.export_metadata.module_call_spec
+        )
+        assert out.backend_input is not None
+        graph_module.meta["fake_mode"] = out.backend_input.fake_mode  # type: ignore[attr-defined]
+        return graph_module
 
     return inner
@@ -2747,12 +2747,14 @@ class DynamoTracerOutput:
     error_on_graph_break: bool
     is_tracing_resume_prologue: bool
     output_graph: Optional[OutputGraph]
+    closure: Optional[tuple[Any, ...]]
 
     def __init__(
         self, tracer: "InstructionTranslatorBase", error: Optional[Any] = None
     ) -> None:
         self.error_on_graph_break = tracer.error_on_graph_break
         self.is_tracing_resume_prologue = tracer.is_tracing_resume_prologue
+        self.closure = tracer.closure
         if error:
             self.output_graph = None
         else:
@@ -4217,6 +4217,7 @@ class InstructionTranslatorBase(
         self.f_builtins: dict[str, Any] = f_builtins
         self.code_options: dict[str, Any] = code_options
         self.f_code: types.CodeType = f_code
+        self.closure = closure
 
         # Execution record for replaying errors
         if closure is not None and config.replay_record_enabled:
@@ -97,7 +97,7 @@ from torch.fx.experimental.symbolic_shapes import (
     GuardOnDataDependentSymNode,
     ShapeEnv,
 )
-from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from torch.fx.graph import _PyTreeInfo
 from torch.utils._pytree import TreeSpec
 from torch.utils._sympy.value_ranges import ValueRangeError
 
@@ -1537,12 +1537,10 @@ def _strict_export(
 
     orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]
 
-    gm_torch_level.graph._codegen = _PyTreeCodeGen(
-        _PyTreeInfo(
-            orig_arg_names,
-            gm_torch_level._in_spec,
-            out_spec,
-        )
+    gm_torch_level.graph._codegen.pytree_info = _PyTreeInfo(
+        orig_arg_names,
+        gm_torch_level._in_spec,
+        out_spec,
     )
     gm_torch_level.recompile()
 
@@ -1489,12 +1489,20 @@ def wrap_key(
 
     @functools.wraps(f)
     def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R:
+        nonlocal tensors
+
         flat_proxies, _proxies_spec = pytree.tree_flatten(proxies)
         assert len(flat_proxies) == len(flat_tensors)
         with disable_proxy_modes_tracing() as m:
             assert isinstance(m, ProxyTorchDispatchMode)
             track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
 
+        if getattr(tracer, "proxy_module_inputs", False):
+            tensors = [  # type: ignore[assignment, var-annotated]
+                p if isinstance(t, torch.nn.Module) else t
+                for t, p in zip(tensors, proxies)  # type: ignore[arg-type]
+            ]
+
         def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
             return get_proxy_slot(t, tracer, t, lambda x: x.proxy)  # type: ignore[attr-defined]
 
@@ -2208,6 +2216,7 @@ class _MakefxTracer:
         _error_on_data_dependent_ops: bool,
         record_stack_traces: bool = False,
         parent_tracer: Optional[_MakefxTracer] = None,
+        proxy_module_inputs: bool = False,
     ) -> None:
         # Configurations that are used to initialize the context managers and their states.
         # Should not modify them during tracing.
@@ -2240,6 +2249,7 @@ class _MakefxTracer:
         )
         self.record_stack_traces = record_stack_traces
         self.parent_tracer: Optional[_MakefxTracer] = parent_tracer
+        self.proxy_module_inputs = proxy_module_inputs
 
     def _checkpoint_modes(self) -> list[Any]:
         return [
@@ -2349,6 +2359,7 @@ class _MakefxTracer:
             self.python_dispatcher_mode = enable_python_dispatcher()
 
         self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer)
+        fx_tracer.proxy_module_inputs = self.proxy_module_inputs  # type: ignore[union-attr]
 
     @contextmanager
     def _init_modes_from_parent(
@@ -2551,6 +2562,7 @@ def make_fx(
     _allow_fake_constant: bool = False,
     _error_on_data_dependent_ops: bool = True,
     record_stack_traces: bool = False,
+    proxy_module_inputs: bool = False,
 ) -> Callable[..., GraphModule]:
     """
     Given a function f, return a new function which when executed with valid
@@ -2574,6 +2586,7 @@ def make_fx(
         _error_on_data_dependent_ops,
         record_stack_traces=record_stack_traces
         or config.trace.provenance_tracking_level == 1,
+        proxy_module_inputs=proxy_module_inputs,
     )
 
     @functools.wraps(f)
@@ -930,6 +930,42 @@ class _PyTreeCodeGen(CodeGen):
         else:
             return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
 
+    def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
+        # when kwargs is present, in_spec is tuple(args, kwargs)
+        has_args_kwargs_tuple = (
+            self.pytree_info.in_spec.type is tuple
+            and self.pytree_info.in_spec.num_children == 2
+            and self.pytree_info.in_spec.children_specs[0].type is tuple
+            and self.pytree_info.in_spec.children_specs[1].type is dict
+        )
+        fn_kwargs = "{}"
+        fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
+        if has_args_kwargs_tuple:
+            count_args = self.pytree_info.in_spec.children_specs[0].num_children
+            fn_args = self.pytree_info.orig_args[:count_args]
+            fn_kwargs = (
+                "{"
+                + ", ".join(
+                    f"'{k}':{v}"
+                    for k, v in zip(
+                        self.pytree_info.in_spec.children_specs[1].context,
+                        self.pytree_info.orig_args[count_args:],
+                    )
+                )
+                + "}"
+            )
+            fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
+
+        # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
+        # we need to split it to two lines:
+        # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
+        # one for code: `var1, var2, = function_call()`
+        without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
+        bindings = self._format_annotations(free_vars, expanded_def)
+        bindings += f"""
+    {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
+        return bindings
+
     def gen_fn_def(
         self, free_vars, maybe_return_annotation, *, expanded_def: bool = False
     ):
@@ -962,39 +998,7 @@ class _PyTreeCodeGen(CodeGen):
         )
 
         if len(free_vars) > 0:  # pytree has placeholders in it
-            # when kwargs is present, in_spec is tuple(args, kwargs)
-            has_args_kwargs_tuple = (
-                self.pytree_info.in_spec.type is tuple
-                and self.pytree_info.in_spec.num_children == 2
-                and self.pytree_info.in_spec.children_specs[0].type is tuple
-                and self.pytree_info.in_spec.children_specs[1].type is dict
-            )
-            fn_kwargs = "{}"
-            fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
-            if has_args_kwargs_tuple:
-                count_args = self.pytree_info.in_spec.children_specs[0].num_children
-                fn_args = self.pytree_info.orig_args[:count_args]
-                fn_kwargs = (
-                    "{"
-                    + ", ".join(
-                        f"'{k}':{v}"
-                        for k, v in zip(
-                            self.pytree_info.in_spec.children_specs[1].context,
-                            self.pytree_info.orig_args[count_args:],
-                        )
-                    )
-                    + "}"
-                )
-                fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
-
-            # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
-            # we need to split it to two lines:
-            # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
-            # one for code: `var1, var2, = function_call()`
-            without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
-            fn_definition += self._format_annotations(free_vars, expanded_def)
-            fn_definition += f"""
-    {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
+            fn_definition += self.gen_var_bindings(fn_args, free_vars, expanded_def)
         return fn_definition
 
     def generate_output(self, output_args, *, descs: Optional[Any] = None):
@@ -1014,6 +1018,52 @@ class _PyTreeCodeGen(CodeGen):
         return super().generate_output(output_args, descs=descs)
 
 
+class _ExportCodeGen(_PyTreeCodeGen):
+    def __init__(
+        self,
+        pytree_info: _PyTreeInfo,
+        in_shuffle_graph: "GraphModule",
+        out_shuffle_graph: "GraphModule",
+        tree_leaf_names: list[str],
+        root: Optional[torch.nn.Module],
+    ):
+        super().__init__(pytree_info)
+        self.in_shuffle_graph = in_shuffle_graph
+        self.out_shuffle_graph = out_shuffle_graph
+        self.tree_leaf_names = tree_leaf_names
+        self.root = root
+
+    def process_inputs(self, *inputs: Any) -> Any:
+        flat_args = super().process_inputs(*inputs)
+        if self.root is not None:
+            flat_args = (self.root, *flat_args)
+        self.flat_args = flat_args
+        return self.in_shuffle_graph(*flat_args)
+
+    def process_outputs(self, out: Any) -> Any:
+        flat_outs = self.out_shuffle_graph(*self.flat_args, *out)
+        del self.flat_args
+        ret = super().process_outputs(flat_outs)
+        return ret
+
+    def gen_fn_def(self, *args, **kwargs) -> str:
+        fn_def = super().gen_fn_def(*args, **kwargs)
+        return fn_def
+
+    def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
+        without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
+        fn_signature: str = f"{', '.join(fn_args)}"
+        if self.root is not None:
+            fn_signature = f"self, {fn_signature}"
+        return f"""
+    {", ".join(self.tree_leaf_names)}, = pytree.tree_leaves(({fn_signature},))
+    {", ".join(without_annotation)}, = self._in_shuffle_graph({", ".join(self.tree_leaf_names)})"""
+
+    def generate_output(self, output_args, *args, **kwargs) -> str:
+        output = f"self._out_shuffle_graph({', '.join(self.tree_leaf_names)}, {', '.join([str(a) for a in output_args])})"
+        return f"return pytree.tree_unflatten({output}, self._out_spec)"
+
+
 class _FindNodesLookupTable:
     """
     Side table for the graph for the purpose of doing fast queries