mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Unified graph capture with fullgraph_capture. (#165562)
Summary: _dynamo_graph_capture_for_export in the current form has the compability issue with the main torch.compile() path despite we reuse fullgraph_capture as the bytecode tracer. The reason is that we flip on many export specific flags and even trace with a wrapped function which will cause divergence with torch.compile() again. This PR instead creates a new implementation of dynamo_graph_capture_for_export which 100% relies on fullgraph capture and post-processing on CaptureOutput so that we can avoid the inversion of phases in PT2 compiler stack. This also benefits precompile workflow since we want to have a feature that only accepts pytree inputs and ship portable python wrappers in package. In other words, I think the code here is sharable between export and precompile for exporting portable graph. Test Plan: ===================================================================== test session starts ===================================================================== platform linux -- Python 3.12.11, pytest-7.3.2, pluggy-1.6.0 rootdir: /data/users/zhxchen17/pytorch configfile: pytest.ini plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, subtests-0.13.1, rerunfailures-14.0, flakefinder-1.1.0, cpp-2.3.0, anyio-4.10.0 collected 9 items Running 9 items in this shard test/distributed/tensor/test_dtensor_export.py ........x [100%] ================================================================ 8 passed, 1 xfailed in 11.42s ================================================================ Reviewers: Subscribers: Tasks: Tags: Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/165562 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
parent
291712026b
commit
757975ad50
|
|
@ -6,7 +6,10 @@ import unittest
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.fx.traceback as fx_traceback
|
import torch.fx.traceback as fx_traceback
|
||||||
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
|
from torch._dynamo.functional_export import (
|
||||||
|
_dynamo_graph_capture_for_export,
|
||||||
|
dynamo_graph_capture_for_export,
|
||||||
|
)
|
||||||
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
|
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
|
||||||
from torch._functorch.partitioners import min_cut_rematerialization_partition
|
from torch._functorch.partitioners import min_cut_rematerialization_partition
|
||||||
from torch._guards import tracing, TracingContext
|
from torch._guards import tracing, TracingContext
|
||||||
|
|
@ -96,6 +99,13 @@ def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
|
||||||
return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
|
return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def graph_capture_and_aot_export_joint_with_descriptors_v2(model, inputs):
|
||||||
|
gm = dynamo_graph_capture_for_export(model)(inputs)
|
||||||
|
fake_mode = gm.meta.get("fake_mode", None)
|
||||||
|
with tracing(TracingContext(fake_mode)):
|
||||||
|
return aot_export_joint_with_descriptors_alone(gm, inputs)
|
||||||
|
|
||||||
|
|
||||||
def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
|
def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
|
||||||
with torch._dynamo.config.patch(install_free_tensors=True):
|
with torch._dynamo.config.patch(install_free_tensors=True):
|
||||||
# TODO: switch to use the official graph_capture API once it is ready
|
# TODO: switch to use the official graph_capture API once it is ready
|
||||||
|
|
@ -288,6 +298,7 @@ class DTensorExportTest(TestCase):
|
||||||
@parametrize(
|
@parametrize(
|
||||||
"export_fn",
|
"export_fn",
|
||||||
[
|
[
|
||||||
|
graph_capture_and_aot_export_joint_with_descriptors_v2,
|
||||||
graph_capture_and_aot_export_joint_with_descriptors,
|
graph_capture_and_aot_export_joint_with_descriptors,
|
||||||
aot_export_joint_with_descriptors_alone,
|
aot_export_joint_with_descriptors_alone,
|
||||||
],
|
],
|
||||||
|
|
@ -307,7 +318,21 @@ class DTensorExportTest(TestCase):
|
||||||
def test_annotate_aot_export_joint_with_descriptors_alone(self):
|
def test_annotate_aot_export_joint_with_descriptors_alone(self):
|
||||||
self._run_test(aot_export_joint_with_descriptors_alone, True)
|
self._run_test(aot_export_joint_with_descriptors_alone, True)
|
||||||
|
|
||||||
def test_dynamic_shapes(self):
|
@parametrize(
|
||||||
|
"export_fn_with_answer",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
graph_capture_and_aot_export_joint_with_descriptors_v2,
|
||||||
|
"[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
graph_capture_and_aot_export_joint_with_descriptors,
|
||||||
|
"[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_dynamic_shapes(self, export_fn_with_answer):
|
||||||
|
export_fn, answer = export_fn_with_answer
|
||||||
dp_degree = 2
|
dp_degree = 2
|
||||||
tp_degree = self.world_size // dp_degree
|
tp_degree = self.world_size // dp_degree
|
||||||
|
|
||||||
|
|
@ -331,7 +356,7 @@ class DTensorExportTest(TestCase):
|
||||||
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
|
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
|
||||||
torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
|
torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
|
||||||
|
|
||||||
joint_gm = graph_capture_and_aot_export_joint_with_descriptors(tp_model, inputs)
|
joint_gm = export_fn(tp_model, inputs)
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
for node in joint_gm.graph.nodes:
|
for node in joint_gm.graph.nodes:
|
||||||
|
|
@ -341,12 +366,16 @@ class DTensorExportTest(TestCase):
|
||||||
if isinstance(fake_val, torch._subclasses.fake_tensor.FakeTensor):
|
if isinstance(fake_val, torch._subclasses.fake_tensor.FakeTensor):
|
||||||
res.append(list(fake_val.shape))
|
res.append(list(fake_val.shape))
|
||||||
|
|
||||||
self.assertExpectedInline(
|
self.assertEqual(str(res), answer)
|
||||||
str(res),
|
|
||||||
"""[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_einsum_dtensor_export(self):
|
@parametrize(
|
||||||
|
"export_fn",
|
||||||
|
[
|
||||||
|
dynamo_graph_capture_for_export,
|
||||||
|
_dynamo_graph_capture_for_export,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_einsum_dtensor_export(self, export_fn):
|
||||||
"""Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
|
"""Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
|
||||||
world_size = 4
|
world_size = 4
|
||||||
# Create device mesh
|
# Create device mesh
|
||||||
|
|
@ -366,9 +395,7 @@ class DTensorExportTest(TestCase):
|
||||||
output = model(x_dtensor, y_dtensor, z_dtensor)
|
output = model(x_dtensor, y_dtensor, z_dtensor)
|
||||||
with torch._dynamo.config.patch(install_free_tensors=True):
|
with torch._dynamo.config.patch(install_free_tensors=True):
|
||||||
# TODO: switch to use the official graph_capture API once it is ready
|
# TODO: switch to use the official graph_capture API once it is ready
|
||||||
gm = _dynamo_graph_capture_for_export(model)(
|
gm = export_fn(model)(x_dtensor, y_dtensor, z_dtensor)
|
||||||
x_dtensor, y_dtensor, z_dtensor
|
|
||||||
)
|
|
||||||
output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
|
output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
|
||||||
self.assertEqual(output, output_gm)
|
self.assertEqual(output, output_gm)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -471,6 +471,67 @@ from user code:
|
||||||
assert hasattr(backend_result.compiled_fn, "serialize")
|
assert hasattr(backend_result.compiled_fn, "serialize")
|
||||||
self.assertIsNotNone(backend_result.compiled_fn.serialize)
|
self.assertIsNotNone(backend_result.compiled_fn.serialize)
|
||||||
|
|
||||||
|
def test_fullgraph_capture_with_pytree_module(self):
|
||||||
|
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
|
||||||
|
|
||||||
|
class Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = torch.nn.Linear(3, 3)
|
||||||
|
self.linear1 = torch.nn.Linear(3, 3)
|
||||||
|
self.linear2 = torch.nn.Linear(3, 3)
|
||||||
|
self.linear3 = torch.nn.Linear(3, 3)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return {
|
||||||
|
"y": self.linear2(x[2] + 1),
|
||||||
|
"z": self.linear3(x[1] - 1),
|
||||||
|
"w": self.linear(x[0]["b"] + 2),
|
||||||
|
"v": self.linear1(x[0]["a"] - 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
mod = Module()
|
||||||
|
compiled_mod = dynamo_graph_capture_for_export(mod)(
|
||||||
|
(
|
||||||
|
{"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
|
||||||
|
torch.randn(3, 3),
|
||||||
|
torch.randn(3, 3),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = (
|
||||||
|
{"a": torch.randn(3, 3), "b": torch.randn(3, 3)},
|
||||||
|
torch.randn(3, 3),
|
||||||
|
torch.randn(3, 3),
|
||||||
|
)
|
||||||
|
self.assertEqual(compiled_mod(inputs), mod(inputs))
|
||||||
|
|
||||||
|
def test_fullgraph_capture_with_pytree_func(self):
|
||||||
|
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
|
||||||
|
|
||||||
|
def foo(x):
|
||||||
|
return {
|
||||||
|
"y": x[2] + 1,
|
||||||
|
"z": x[1] - 1,
|
||||||
|
"w": x[0]["b"] + 2,
|
||||||
|
"v": x[0]["a"] - 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
compiled_foo = dynamo_graph_capture_for_export(foo)(
|
||||||
|
(
|
||||||
|
{"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
|
||||||
|
torch.randn(2, 3),
|
||||||
|
torch.randn(3, 4),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = (
|
||||||
|
{"a": torch.randn(4, 3), "b": torch.randn(3, 2)},
|
||||||
|
torch.randn(2, 3),
|
||||||
|
torch.randn(3, 4),
|
||||||
|
)
|
||||||
|
self.assertEqual(compiled_foo(inputs), foo(inputs))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
|
||||||
|
|
@ -402,6 +402,43 @@ def forward(self, x):
|
||||||
|
|
||||||
self.assertEqual(res_export, res_eager)
|
self.assertEqual(res_export, res_eager)
|
||||||
|
|
||||||
|
def test_dynamo_graph_capture(self):
|
||||||
|
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
|
||||||
|
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def forward(self, dct, lst, bleh):
|
||||||
|
x = dct["a"] * lst[1][0]
|
||||||
|
y = dct["b"] * lst[0]
|
||||||
|
out_dict = {}
|
||||||
|
|
||||||
|
# Mutate and get a new entry in there
|
||||||
|
lst_copy = lst.copy()
|
||||||
|
lst_copy.append(lst[0])
|
||||||
|
out_dict["a"] = x
|
||||||
|
out_dict["b"] = y
|
||||||
|
return (
|
||||||
|
dct["a"],
|
||||||
|
out_dict["b"],
|
||||||
|
bleh,
|
||||||
|
lst_copy[-1],
|
||||||
|
out_dict["a"],
|
||||||
|
[5, 6],
|
||||||
|
)
|
||||||
|
|
||||||
|
foo = Foo()
|
||||||
|
|
||||||
|
def make_inputs():
|
||||||
|
return (
|
||||||
|
{"a": torch.randn(2, 3), "b": torch.randn(2, 3)},
|
||||||
|
[torch.randn(2, 3), (torch.randn(2, 3),)],
|
||||||
|
torch.randn(2, 3),
|
||||||
|
)
|
||||||
|
|
||||||
|
trace_inputs = make_inputs()
|
||||||
|
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
|
||||||
|
test_inputs = make_inputs()
|
||||||
|
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
|
|
@ -886,6 +886,7 @@ class DynamoOutput:
|
||||||
return GraphCaptureOutput(
|
return GraphCaptureOutput(
|
||||||
OutputGraphCommon(
|
OutputGraphCommon(
|
||||||
output_graph.dump_guards_state(),
|
output_graph.dump_guards_state(),
|
||||||
|
output_graph.import_sources,
|
||||||
output_graph.shape_env,
|
output_graph.shape_env,
|
||||||
output_graph.export_metadata,
|
output_graph.export_metadata,
|
||||||
output_graph.tracked_fakes_id_to_source,
|
output_graph.tracked_fakes_id_to_source,
|
||||||
|
|
@ -960,6 +961,27 @@ class CaptureOutput:
|
||||||
# BackendInput can be None when dynamo didn't compile any graph (no tensor op)
|
# BackendInput can be None when dynamo didn't compile any graph (no tensor op)
|
||||||
backend_input: Optional[BackendInput]
|
backend_input: Optional[BackendInput]
|
||||||
|
|
||||||
|
def forward_callable(self) -> Callable[..., Any]:
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
# TODO code sharing
|
||||||
|
import_sources = self.graph_capture_output.output_graph.import_sources
|
||||||
|
assert self.backend_input is not None
|
||||||
|
backend_id = self.backend_input.backend_id
|
||||||
|
import_sources = {
|
||||||
|
alias: importlib.import_module(module_name)
|
||||||
|
for alias, module_name in import_sources.items()
|
||||||
|
}
|
||||||
|
f_globals = {
|
||||||
|
**import_sources,
|
||||||
|
backend_id: self.backend_input.graph_module,
|
||||||
|
}
|
||||||
|
return types.FunctionType(
|
||||||
|
self.graph_capture_output.bytecode,
|
||||||
|
f_globals,
|
||||||
|
closure=(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
|
def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
|
import types
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
|
@ -26,6 +28,7 @@ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||||
|
from torch.utils._pytree import TreeSpec
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
@ -446,6 +449,140 @@ def _suggest_or_raise_constraint_violation(
|
||||||
raise constraint_violation_error
|
raise constraint_violation_error
|
||||||
|
|
||||||
|
|
||||||
|
def pytreeify(
|
||||||
|
out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Given a dynamo capture output, return a callable graph module that
|
||||||
|
contain the following information:
|
||||||
|
1. input/output pytree spec
|
||||||
|
2. input/output shuffle functions
|
||||||
|
Input shuffle functions are the converters taking pytree falttened inputs
|
||||||
|
and reorder them to the calling convention of dynamo raw graph module.
|
||||||
|
Output shuffle functions are the converters taking the outputs of the
|
||||||
|
dynamo raw graph module and convert them to the pytree format.
|
||||||
|
|
||||||
|
This function will replay any side effects that happened during the bytecode,
|
||||||
|
so it is important to check against side effects before calling this function.
|
||||||
|
"""
|
||||||
|
assert out.backend_input is not None
|
||||||
|
backend_input = out.backend_input
|
||||||
|
backend = out.backend_input.graph_module
|
||||||
|
|
||||||
|
if isinstance(mod, torch.nn.Module):
|
||||||
|
args = (mod,) + args
|
||||||
|
elif inspect.ismethod(mod):
|
||||||
|
args = (mod.__self__,) + args
|
||||||
|
|
||||||
|
flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
|
||||||
|
|
||||||
|
class Yield(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class InShuffle(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mod = mod
|
||||||
|
self.num_inputs = len(flat_real_args)
|
||||||
|
self.gm_inputs = None
|
||||||
|
|
||||||
|
def forward(self, *flat_proxy_args):
|
||||||
|
args, kwargs = pytree.tree_unflatten(
|
||||||
|
[flat_proxy_args[i] for i in range(self.num_inputs)], in_spec
|
||||||
|
)
|
||||||
|
|
||||||
|
def backend_dummy(*example_inputs):
|
||||||
|
self.gm_inputs = example_inputs
|
||||||
|
raise Yield
|
||||||
|
|
||||||
|
backend_input.graph_module = backend_dummy # type: ignore[assignment]
|
||||||
|
try:
|
||||||
|
out.forward_callable()(*args, **kwargs)
|
||||||
|
except Yield:
|
||||||
|
assert self.gm_inputs is not None
|
||||||
|
return self.gm_inputs
|
||||||
|
finally:
|
||||||
|
backend_input.graph_module = backend
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
in_shuffle_graph = torch.fx.symbolic_trace(InShuffle())
|
||||||
|
|
||||||
|
class OutShuffle(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.num_inputs = len(flat_real_args)
|
||||||
|
self.num_outputs = len(
|
||||||
|
next(iter(reversed(backend_input.graph_module.graph.nodes))).args[0]
|
||||||
|
)
|
||||||
|
self.out_spec: Optional[TreeSpec] = None
|
||||||
|
|
||||||
|
def forward(self, *flat_proxy_args):
|
||||||
|
args, kwargs = pytree.tree_unflatten(
|
||||||
|
[flat_proxy_args[i] for i in range(self.num_inputs)], in_spec
|
||||||
|
)
|
||||||
|
|
||||||
|
def backend_dummy(*example_inputs):
|
||||||
|
return [
|
||||||
|
flat_proxy_args[self.num_inputs + i]
|
||||||
|
for i in range(self.num_outputs)
|
||||||
|
]
|
||||||
|
|
||||||
|
backend_input.graph_module = backend_dummy # type: ignore[assignment]
|
||||||
|
try:
|
||||||
|
results = out.forward_callable()(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
backend_input.graph_module = backend
|
||||||
|
ret, self.out_spec = pytree.tree_flatten(results)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
out_shuffle = OutShuffle()
|
||||||
|
out_shuffle_graph = torch.fx.symbolic_trace(out_shuffle)
|
||||||
|
|
||||||
|
def pytree_call(*args, **kwargs):
|
||||||
|
import torch.export._unlift
|
||||||
|
|
||||||
|
flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs))
|
||||||
|
if not torch.export._unlift.eq_spec(in_spec_runtime, in_spec):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Model input mismatch. Expected input spec: {in_spec}. Actual input spec: {in_spec_runtime}"
|
||||||
|
)
|
||||||
|
flat_outs = backend_input.graph_module(*in_shuffle_graph(*flat_args))
|
||||||
|
assert out_shuffle.out_spec is not None
|
||||||
|
return pytree.tree_unflatten(
|
||||||
|
out_shuffle_graph(*flat_args, *flat_outs), out_shuffle.out_spec
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(mod, torch.nn.Module):
|
||||||
|
compiled_mod = copy.copy(mod)
|
||||||
|
compiled_mod.forward = types.MethodType(pytree_call, compiled_mod)
|
||||||
|
if not hasattr(compiled_mod, "meta"):
|
||||||
|
compiled_mod.meta = {} # type: ignore[attr-defined]
|
||||||
|
if isinstance(compiled_mod.meta, dict) and "fake_mode" not in compiled_mod.meta:
|
||||||
|
compiled_mod.meta["fake_mode"] = out.backend_input.fake_mode
|
||||||
|
return compiled_mod
|
||||||
|
elif inspect.ismethod(mod):
|
||||||
|
return types.MethodType(pytree_call, mod.__self__)
|
||||||
|
else:
|
||||||
|
return pytree_call
|
||||||
|
|
||||||
|
|
||||||
|
def dynamo_graph_capture_for_export(
|
||||||
|
mod: Callable[..., Any],
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
def inner(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
with (
|
||||||
|
get_metrics_context(),
|
||||||
|
dynamo_timed("fullgraph_capture"),
|
||||||
|
):
|
||||||
|
out = fullgraph_capture(mod, args, kwargs)
|
||||||
|
|
||||||
|
# TODO filter out side effects.
|
||||||
|
|
||||||
|
return pytreeify(out, mod, args, kwargs)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
def _dynamo_graph_capture_for_export(
|
def _dynamo_graph_capture_for_export(
|
||||||
mod: Callable[..., Any],
|
mod: Callable[..., Any],
|
||||||
*,
|
*,
|
||||||
|
|
|
||||||
|
|
@ -463,6 +463,7 @@ class OutputGraphCommon(OutputGraphGuardsState):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
output_graph_guards_state: OutputGraphGuardsState,
|
output_graph_guards_state: OutputGraphGuardsState,
|
||||||
|
import_sources: Optional[dict[str, str]] = None,
|
||||||
shape_env: Optional[ShapeEnv] = None,
|
shape_env: Optional[ShapeEnv] = None,
|
||||||
export_metadata: Optional[ExportMetaData] = None,
|
export_metadata: Optional[ExportMetaData] = None,
|
||||||
tracked_fakes_id_to_source: Optional[dict[int, list[Source]]] = None,
|
tracked_fakes_id_to_source: Optional[dict[int, list[Source]]] = None,
|
||||||
|
|
@ -485,6 +486,7 @@ class OutputGraphCommon(OutputGraphGuardsState):
|
||||||
output_graph_guards_state.name_of_builtins_dict_key_in_fglobals,
|
output_graph_guards_state.name_of_builtins_dict_key_in_fglobals,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.import_sources = import_sources or {}
|
||||||
# The following fields are currently known to be used by clients.
|
# The following fields are currently known to be used by clients.
|
||||||
# In particular, we need:
|
# In particular, we need:
|
||||||
# - shape_env, for building guards
|
# - shape_env, for building guards
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user