[export] Unified graph capture with fullgraph_capture. (#165562)
Summary:

_dynamo_graph_capture_for_export in its current form has a compatibility issue with the main torch.compile() path, even though we reuse fullgraph_capture as the bytecode tracer: we flip on many export-specific flags and even trace a wrapped function, which again causes divergence from torch.compile(). This PR instead adds a new implementation, dynamo_graph_capture_for_export, which relies entirely on fullgraph_capture plus post-processing of CaptureOutput, so we avoid the inversion of phases in the PT2 compiler stack. This also benefits the precompile workflow, since we want a capture mode that accepts only pytree inputs and ships portable Python wrappers in the package. In other words, the code here should be sharable between export and precompile for exporting a portable graph.

Test Plan:

===================================================================== test session starts =====================================================================
platform linux -- Python 3.12.11, pytest-7.3.2, pluggy-1.6.0
rootdir: /data/users/zhxchen17/pytorch
configfile: pytest.ini
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, subtests-0.13.1, rerunfailures-14.0, flakefinder-1.1.0, cpp-2.3.0, anyio-4.10.0
collected 9 items
Running 9 items in this shard

test/distributed/tensor/test_dtensor_export.py ........x  [100%]

================================================================ 8 passed, 1 xfailed in 11.42s ================================================================

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165562
Approved by: https://github.com/tugsbayasgalan
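For illustration, the new entry point is used the same way as in the tests added by this PR: wrap the module, call the wrapper once with example inputs to capture, then call the returned artifact with the original pytree calling convention. A minimal sketch; the toy module and shapes below are illustrative and not part of this PR:

    import torch
    from torch._dynamo.functional_export import dynamo_graph_capture_for_export

    class Toy(torch.nn.Module):  # illustrative module, not from this PR
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            return {"out": self.linear(x["a"] + x["b"])}

    mod = Toy()
    example = {"a": torch.randn(2, 3), "b": torch.randn(2, 3)}

    # Capture goes through the same fullgraph_capture path as torch.compile(),
    # then CaptureOutput is post-processed into a callable with pytree I/O.
    gm = dynamo_graph_capture_for_export(mod)(example)

    # The returned callable keeps the original calling convention.
    torch.testing.assert_close(gm(example)["out"], mod(example)["out"])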
parent 291712026b
commit 757975ad50
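The pytreeify post-processing added in this diff records the input TreeSpec at capture time and validates it again at call time before handing the flattened tensors to the raw dynamo graph. A standalone sketch of that flatten/validate/unflatten round trip, using torch.utils._pytree directly (the helper names here are illustrative, not code from this PR):

    import torch
    from torch.utils import _pytree as pytree

    def record_in_spec(args, kwargs):
        # Capture time: remember how the inputs flatten.
        flat, in_spec = pytree.tree_flatten((args, kwargs))
        return flat, in_spec

    def flatten_checked(in_spec, args, kwargs):
        # Call time: the runtime inputs must flatten to the same structure.
        flat, runtime_spec = pytree.tree_flatten((args, kwargs))
        if runtime_spec != in_spec:
            raise RuntimeError(
                f"Model input mismatch. Expected input spec: {in_spec}. "
                f"Actual input spec: {runtime_spec}"
            )
        return flat

    args = ({"a": torch.randn(2, 3), "b": torch.randn(2, 3)}, torch.randn(2, 3))
    _, spec = record_in_spec(args, {})
    flat_inputs = flatten_checked(spec, args, {})  # same nesting -> accepted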
@@ -6,7 +6,10 @@ import unittest
 import torch
 import torch.distributed as dist
 import torch.fx.traceback as fx_traceback
-from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+from torch._dynamo.functional_export import (
+    _dynamo_graph_capture_for_export,
+    dynamo_graph_capture_for_export,
+)
 from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
 from torch._functorch.partitioners import min_cut_rematerialization_partition
 from torch._guards import tracing, TracingContext
@@ -96,6 +99,13 @@ def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
     return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
 
 
+def graph_capture_and_aot_export_joint_with_descriptors_v2(model, inputs):
+    gm = dynamo_graph_capture_for_export(model)(inputs)
+    fake_mode = gm.meta.get("fake_mode", None)
+    with tracing(TracingContext(fake_mode)):
+        return aot_export_joint_with_descriptors_alone(gm, inputs)
+
+
 def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
     with torch._dynamo.config.patch(install_free_tensors=True):
         # TODO: switch to use the official graph_capture API once it is ready
@@ -288,6 +298,7 @@ class DTensorExportTest(TestCase):
     @parametrize(
         "export_fn",
         [
+            graph_capture_and_aot_export_joint_with_descriptors_v2,
             graph_capture_and_aot_export_joint_with_descriptors,
             aot_export_joint_with_descriptors_alone,
         ],
@@ -307,7 +318,21 @@ class DTensorExportTest(TestCase):
     def test_annotate_aot_export_joint_with_descriptors_alone(self):
         self._run_test(aot_export_joint_with_descriptors_alone, True)
 
-    def test_dynamic_shapes(self):
+    @parametrize(
+        "export_fn_with_answer",
+        [
+            (
+                graph_capture_and_aot_export_joint_with_descriptors_v2,
+                "[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
+            ),
+            (
+                graph_capture_and_aot_export_joint_with_descriptors,
+                "[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
+            ),
+        ],
+    )
+    def test_dynamic_shapes(self, export_fn_with_answer):
+        export_fn, answer = export_fn_with_answer
         dp_degree = 2
         tp_degree = self.world_size // dp_degree
 
@@ -331,7 +356,7 @@ class DTensorExportTest(TestCase):
         inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
         torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
 
-        joint_gm = graph_capture_and_aot_export_joint_with_descriptors(tp_model, inputs)
+        joint_gm = export_fn(tp_model, inputs)
 
         res = []
         for node in joint_gm.graph.nodes:
@@ -341,12 +366,16 @@ class DTensorExportTest(TestCase):
             if isinstance(fake_val, torch._subclasses.fake_tensor.FakeTensor):
                 res.append(list(fake_val.shape))
 
-        self.assertExpectedInline(
-            str(res),
-            """[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
-        )
+        self.assertEqual(str(res), answer)
 
-    def test_einsum_dtensor_export(self):
+    @parametrize(
+        "export_fn",
+        [
+            dynamo_graph_capture_for_export,
+            _dynamo_graph_capture_for_export,
+        ],
+    )
+    def test_einsum_dtensor_export(self, export_fn):
         """Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
         world_size = 4
         # Create device mesh
@@ -366,9 +395,7 @@ class DTensorExportTest(TestCase):
         output = model(x_dtensor, y_dtensor, z_dtensor)
-        with torch._dynamo.config.patch(install_free_tensors=True):
-            # TODO: switch to use the official graph_capture API once it is ready
-            gm = _dynamo_graph_capture_for_export(model)(
-                x_dtensor, y_dtensor, z_dtensor
-            )
+        gm = export_fn(model)(x_dtensor, y_dtensor, z_dtensor)
         output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
         self.assertEqual(output, output_gm)
 
@@ -471,6 +471,67 @@ from user code:
         assert hasattr(backend_result.compiled_fn, "serialize")
         self.assertIsNotNone(backend_result.compiled_fn.serialize)
 
+    def test_fullgraph_capture_with_pytree_module(self):
+        from torch._dynamo.functional_export import dynamo_graph_capture_for_export
+
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(3, 3)
+                self.linear1 = torch.nn.Linear(3, 3)
+                self.linear2 = torch.nn.Linear(3, 3)
+                self.linear3 = torch.nn.Linear(3, 3)
+
+            def forward(self, x):
+                return {
+                    "y": self.linear2(x[2] + 1),
+                    "z": self.linear3(x[1] - 1),
+                    "w": self.linear(x[0]["b"] + 2),
+                    "v": self.linear1(x[0]["a"] - 2),
+                }
+
+        mod = Module()
+        compiled_mod = dynamo_graph_capture_for_export(mod)(
+            (
+                {"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
+                torch.randn(3, 3),
+                torch.randn(3, 3),
+            )
+        )
+
+        inputs = (
+            {"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
+            torch.randn(3, 3),
+            torch.randn(3, 3),
+        )
+        self.assertEqual(compiled_mod(inputs), mod(inputs))
+
+    def test_fullgraph_capture_with_pytree_func(self):
+        from torch._dynamo.functional_export import dynamo_graph_capture_for_export
+
+        def foo(x):
+            return {
+                "y": x[2] + 1,
+                "z": x[1] - 1,
+                "w": x[0]["b"] + 2,
+                "v": x[0]["a"] - 2,
+            }
+
+        compiled_foo = dynamo_graph_capture_for_export(foo)(
+            (
+                {"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
+                torch.randn(2, 3),
+                torch.randn(3, 4),
+            )
+        )
+
+        inputs = (
+            {"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
+            torch.randn(2, 3),
+            torch.randn(3, 4),
+        )
+        self.assertEqual(compiled_foo(inputs), foo(inputs))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -402,6 +402,43 @@ def forward(self, x):
 
         self.assertEqual(res_export, res_eager)
 
+    def test_dynamo_graph_capture(self):
+        from torch._dynamo.functional_export import dynamo_graph_capture_for_export
+
+        class Foo(torch.nn.Module):
+            def forward(self, dct, lst, bleh):
+                x = dct["a"] * lst[1][0]
+                y = dct["b"] * lst[0]
+                out_dict = {}
+
+                # Mutate and get a new entry in there
+                lst_copy = lst.copy()
+                lst_copy.append(lst[0])
+                out_dict["a"] = x
+                out_dict["b"] = y
+                return (
+                    dct["a"],
+                    out_dict["b"],
+                    bleh,
+                    lst_copy[-1],
+                    out_dict["a"],
+                    [5, 6],
+                )
+
+        foo = Foo()
+
+        def make_inputs():
+            return (
+                {"a": torch.randn(2, 3), "b": torch.randn(2, 3)},
+                [torch.randn(2, 3), (torch.randn(2, 3),)],
+                torch.randn(2, 3),
+            )
+
+        trace_inputs = make_inputs()
+        gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
+        test_inputs = make_inputs()
+        self.assertEqual(gm(*test_inputs), foo(*test_inputs))
+
 
 if __name__ == "__main__":
     run_tests()
@@ -886,6 +886,7 @@ class DynamoOutput:
         return GraphCaptureOutput(
             OutputGraphCommon(
                 output_graph.dump_guards_state(),
+                output_graph.import_sources,
                 output_graph.shape_env,
                 output_graph.export_metadata,
                 output_graph.tracked_fakes_id_to_source,
@@ -960,6 +961,27 @@ class CaptureOutput:
     # BackendInput can be None when dynamo didn't compile any graph (no tensor op)
     backend_input: Optional[BackendInput]
 
+    def forward_callable(self) -> Callable[..., Any]:
+        import importlib
+
+        # TODO code sharing
+        import_sources = self.graph_capture_output.output_graph.import_sources
+        assert self.backend_input is not None
+        backend_id = self.backend_input.backend_id
+        import_sources = {
+            alias: importlib.import_module(module_name)
+            for alias, module_name in import_sources.items()
+        }
+        f_globals = {
+            **import_sources,
+            backend_id: self.backend_input.graph_module,
+        }
+        return types.FunctionType(
+            self.graph_capture_output.bytecode,
+            f_globals,
+            closure=(),
+        )
+
 
 def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
     """
@@ -1,6 +1,8 @@
+import copy
 import inspect
 import logging
 import traceback
+import types
 from collections import namedtuple
 from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
@@ -26,6 +28,7 @@ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 
 if TYPE_CHECKING:
     from torch._subclasses.fake_tensor import FakeTensorMode
+    from torch.utils._pytree import TreeSpec
 
 
 log = logging.getLogger(__name__)
@@ -446,6 +449,140 @@ def _suggest_or_raise_constraint_violation(
     raise constraint_violation_error
 
 
+def pytreeify(
+    out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> Any:
+    """
+    Given a dynamo capture output, return a callable graph module that
+    contains the following information:
+    1. input/output pytree spec
+    2. input/output shuffle functions
+    Input shuffle functions are the converters taking pytree-flattened inputs
+    and reordering them to the calling convention of the raw dynamo graph module.
+    Output shuffle functions are the converters taking the outputs of the
+    raw dynamo graph module and converting them to the pytree format.
+
+    This function will replay any side effects that happened during the bytecode,
+    so it is important to check against side effects before calling this function.
+    """
+    assert out.backend_input is not None
+    backend_input = out.backend_input
+    backend = out.backend_input.graph_module
+
+    if isinstance(mod, torch.nn.Module):
+        args = (mod,) + args
+    elif inspect.ismethod(mod):
+        args = (mod.__self__,) + args
+
+    flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
+
+    class Yield(Exception):
+        pass
+
+    class InShuffle(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.mod = mod
+            self.num_inputs = len(flat_real_args)
+            self.gm_inputs = None
+
+        def forward(self, *flat_proxy_args):
+            args, kwargs = pytree.tree_unflatten(
+                [flat_proxy_args[i] for i in range(self.num_inputs)], in_spec
+            )
+
+            def backend_dummy(*example_inputs):
+                self.gm_inputs = example_inputs
+                raise Yield
+
+            backend_input.graph_module = backend_dummy  # type: ignore[assignment]
+            try:
+                out.forward_callable()(*args, **kwargs)
+            except Yield:
+                assert self.gm_inputs is not None
+                return self.gm_inputs
+            finally:
+                backend_input.graph_module = backend
+            raise RuntimeError
+
+    in_shuffle_graph = torch.fx.symbolic_trace(InShuffle())
+
+    class OutShuffle(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.num_inputs = len(flat_real_args)
+            self.num_outputs = len(
+                next(iter(reversed(backend_input.graph_module.graph.nodes))).args[0]
+            )
+            self.out_spec: Optional[TreeSpec] = None
+
+        def forward(self, *flat_proxy_args):
+            args, kwargs = pytree.tree_unflatten(
+                [flat_proxy_args[i] for i in range(self.num_inputs)], in_spec
+            )
+
+            def backend_dummy(*example_inputs):
+                return [
+                    flat_proxy_args[self.num_inputs + i]
+                    for i in range(self.num_outputs)
+                ]
+
+            backend_input.graph_module = backend_dummy  # type: ignore[assignment]
+            try:
+                results = out.forward_callable()(*args, **kwargs)
+            finally:
+                backend_input.graph_module = backend
+            ret, self.out_spec = pytree.tree_flatten(results)
+            return ret
+
+    out_shuffle = OutShuffle()
+    out_shuffle_graph = torch.fx.symbolic_trace(out_shuffle)
+
+    def pytree_call(*args, **kwargs):
+        import torch.export._unlift
+
+        flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs))
+        if not torch.export._unlift.eq_spec(in_spec_runtime, in_spec):
+            raise RuntimeError(
+                f"Model input mismatch. Expected input spec: {in_spec}. Actual input spec: {in_spec_runtime}"
+            )
+        flat_outs = backend_input.graph_module(*in_shuffle_graph(*flat_args))
+        assert out_shuffle.out_spec is not None
+        return pytree.tree_unflatten(
+            out_shuffle_graph(*flat_args, *flat_outs), out_shuffle.out_spec
+        )
+
+    if isinstance(mod, torch.nn.Module):
+        compiled_mod = copy.copy(mod)
+        compiled_mod.forward = types.MethodType(pytree_call, compiled_mod)
+        if not hasattr(compiled_mod, "meta"):
+            compiled_mod.meta = {}  # type: ignore[attr-defined]
+        if isinstance(compiled_mod.meta, dict) and "fake_mode" not in compiled_mod.meta:
+            compiled_mod.meta["fake_mode"] = out.backend_input.fake_mode
+        return compiled_mod
+    elif inspect.ismethod(mod):
+        return types.MethodType(pytree_call, mod.__self__)
+    else:
+        return pytree_call
+
+
+def dynamo_graph_capture_for_export(
+    mod: Callable[..., Any],
+) -> Callable[..., Any]:
+    def inner(*args: Any, **kwargs: Any) -> Any:
+        with (
+            get_metrics_context(),
+            dynamo_timed("fullgraph_capture"),
+        ):
+            out = fullgraph_capture(mod, args, kwargs)
+
+        # TODO filter out side effects.
+
+        return pytreeify(out, mod, args, kwargs)
+
+    return inner
+
+
 def _dynamo_graph_capture_for_export(
     mod: Callable[..., Any],
     *,
@@ -463,6 +463,7 @@ class OutputGraphCommon(OutputGraphGuardsState):
     def __init__(
         self,
         output_graph_guards_state: OutputGraphGuardsState,
+        import_sources: Optional[dict[str, str]] = None,
         shape_env: Optional[ShapeEnv] = None,
         export_metadata: Optional[ExportMetaData] = None,
         tracked_fakes_id_to_source: Optional[dict[int, list[Source]]] = None,
@@ -485,6 +486,7 @@ class OutputGraphCommon(OutputGraphGuardsState):
             output_graph_guards_state.name_of_builtins_dict_key_in_fglobals,
         )
 
+        self.import_sources = import_sources or {}
         # The following fields are currently known to be used by clients.
         # In particular, we need:
         # - shape_env, for building guards