[dynamo] Guard serialization for DUAL LEVEL. (#152615)

Seem dual level counter should be stored in OutputGraph so that the value can be preserved through roundtripping.

Differential Revision: [D74008786](https://our.internmc.facebook.com/intern/diff/D74008786/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152615
Approved by: https://github.com/jansel, https://github.com/zou3519
This commit is contained in:
zhxchen17 2025-05-01 09:50:34 -07:00 committed by PyTorch MergeBot
parent 0145f9e29e
commit 1d1cbcd8a3
3 changed files with 17 additions and 1 deletions

View File

@ -314,6 +314,18 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
finally:
op.__name__ = prev
def test_dual_level(self):
def fn(x):
with torch.autograd.forward_ad.dual_level():
return x + 1
x = torch.randn(3)
ref, loaded = self._test_serialization("DUAL_LEVEL", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
with torch.autograd.forward_ad.dual_level():
self._test_check_fn(ref, loaded, {"x": x}, False)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1575,7 +1575,7 @@ class GuardBuilder(GuardBuilderBase):
def DUAL_LEVEL(self, guard: Guard):
# Invalidate dual level if current dual level is different than the one
# in the fx graph
dual_level = torch.autograd.forward_ad._current_level
dual_level = self.check_fn_manager.output_graph.dual_level
code = [f"torch.autograd.forward_ad._current_level == {dual_level}"]
self._set_guard_export_info(guard, [code])
# TODO(anijain2305) - Consider this moving this guard to C++

View File

@ -300,6 +300,8 @@ class OutputGraphGuardsState:
guard_on_key_order: set[Source]
# Map from graph input's `Source` to sizes / strides metadata
input_source_to_sizes_strides: dict[Source, dict[str, Any]]
dual_level: int
export: bool = False
export_constraints: bool = False
@ -351,6 +353,7 @@ class OutputGraph(OutputGraphGuardsState):
torch_function_mode_stack,
guard_on_key_order=set(),
input_source_to_sizes_strides={},
dual_level=torch.autograd.forward_ad._current_level,
)
self.tracers = [SubgraphTracer(self, is_export=export)]
# Map from graph input's `Source` to its `VariableTracker` to
@ -586,6 +589,7 @@ class OutputGraph(OutputGraphGuardsState):
torch_function_mode_stack=self.torch_function_mode_stack,
guard_on_key_order=self.guard_on_key_order,
input_source_to_sizes_strides=self.input_source_to_sizes_strides,
dual_level=self.dual_level,
export=self.export,
export_constraints=self.export_constraints,
_guards=self.guards,