[dynamo] Guard serialization for DUAL LEVEL. (#152615)
Seems the dual level counter should be stored in OutputGraph so that the value can be preserved through round-tripping.

Differential Revision: [D74008786](https://our.internmc.facebook.com/intern/diff/D74008786/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152615
Approved by: https://github.com/jansel, https://github.com/zou3519
This commit is contained in:
parent 0145f9e29e
commit 1d1cbcd8a3
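For context before the hunks: `dual_level()` is the public context manager that bumps the private counter `torch.autograd.forward_ad._current_level` on entry and restores it on exit, and that counter is exactly what the DUAL_LEVEL guard compares against. A minimal sketch of the state involved (the -1 default is a private implementation detail of current PyTorch, not a documented API):

    import torch.autograd.forward_ad as fwAD

    # Outside any dual level the private counter holds its default (-1).
    print(fwAD._current_level)   # -1
    with fwAD.dual_level():
        # Entering a dual level increments the counter; the guard compares
        # it against the value recorded when the graph was traced.
        print(fwAD._current_level)   # 0
    print(fwAD._current_level)   # restored to -1 on exit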
@@ -314,6 +314,18 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
             finally:
                 op.__name__ = prev
 
+    def test_dual_level(self):
+        def fn(x):
+            with torch.autograd.forward_ad.dual_level():
+                return x + 1
+
+        x = torch.randn(3)
+        ref, loaded = self._test_serialization("DUAL_LEVEL", fn, x)
+
+        self._test_check_fn(ref, loaded, {"x": x}, True)
+        with torch.autograd.forward_ad.dual_level():
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
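The test round-trips the guards and expects the deserialized check function to pass at the dual level that was active during tracing (none) and to fail once a new level is pushed. Assuming the guard reduces to the string comparison built in DUAL_LEVEL below, the flip can be reproduced by hand:

    import torch
    import torch.autograd.forward_ad as fwAD

    recorded = fwAD._current_level  # level captured at trace time
    guard_src = f"torch.autograd.forward_ad._current_level == {recorded}"

    print(eval(guard_src))          # True: still at the trace-time level
    with fwAD.dual_level():
        print(eval(guard_src))      # False: level changed, guard fails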
@@ -1575,7 +1575,7 @@ class GuardBuilder(GuardBuilderBase):
     def DUAL_LEVEL(self, guard: Guard):
         # Invalidate dual level if current dual level is different than the one
         # in the fx graph
-        dual_level = torch.autograd.forward_ad._current_level
+        dual_level = self.check_fn_manager.output_graph.dual_level
         code = [f"torch.autograd.forward_ad._current_level == {dual_level}"]
         self._set_guard_export_info(guard, [code])
         # TODO(anijain2305) - Consider this moving this guard to C++
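The one-line change above is the crux of the fix: reading `_current_level` at guard-build time bakes in whatever level happens to be live when the check function is (re)constructed, which is wrong once guards are rebuilt after a serialize/deserialize round trip. Pulling the value from `output_graph.dual_level` pins the trace-time level instead. A toy illustration of the difference (`build_guard` is hypothetical, not a Dynamo API):

    import torch.autograd.forward_ad as fwAD

    def build_guard(recorded_level: int) -> str:
        # Correct: embed the trace-time level, not the live global.
        return f"torch.autograd.forward_ad._current_level == {recorded_level}"

    # At trace time, the level is recorded alongside the graph state...
    recorded_level = fwAD._current_level

    # ...so even if guards are rebuilt later under a different live level,
    # the generated check still compares against the trace-time value.
    with fwAD.dual_level():
        stale = f"torch.autograd.forward_ad._current_level == {fwAD._current_level}"  # pre-fix: reads live state
        fixed = build_guard(recorded_level)                                           # post-fix: uses the snapshot
    print(stale)  # compares against 0 (the wrong baseline)
    print(fixed)  # compares against -1 (the trace-time level)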
@@ -300,6 +300,8 @@ class OutputGraphGuardsState:
     guard_on_key_order: set[Source]
     # Map from graph input's `Source` to sizes / strides metadata
     input_source_to_sizes_strides: dict[Source, dict[str, Any]]
+    dual_level: int
+
     export: bool = False
     export_constraints: bool = False
 
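Declaring `dual_level` on `OutputGraphGuardsState` is what lets the value survive serialization: any field on the state dataclass comes back intact from a round trip. A minimal stand-in (`GuardsState` is a toy, and the real pipeline need not use plain pickle):

    import pickle
    from dataclasses import dataclass

    @dataclass
    class GuardsState:
        dual_level: int

    state = GuardsState(dual_level=0)
    restored = pickle.loads(pickle.dumps(state))
    assert restored.dual_level == 0  # trace-time level preserved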
@@ -351,6 +353,7 @@ class OutputGraph(OutputGraphGuardsState):
             torch_function_mode_stack,
             guard_on_key_order=set(),
             input_source_to_sizes_strides={},
+            dual_level=torch.autograd.forward_ad._current_level,
         )
         self.tracers = [SubgraphTracer(self, is_export=export)]
         # Map from graph input's `Source` to its `VariableTracker` to
@@ -586,6 +589,7 @@ class OutputGraph(OutputGraphGuardsState):
             torch_function_mode_stack=self.torch_function_mode_stack,
             guard_on_key_order=self.guard_on_key_order,
             input_source_to_sizes_strides=self.input_source_to_sizes_strides,
+            dual_level=self.dual_level,
             export=self.export,
             export_constraints=self.export_constraints,
             _guards=self.guards,
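Together, the last two hunks thread the value through: the constructor snapshots the live level once at trace time, and the guards-state export copies that snapshot rather than re-reading the global. The pattern in miniature (names are illustrative only):

    import torch.autograd.forward_ad as fwAD

    class Graph:
        def __init__(self) -> None:
            # Snapshot once, at construction (trace time).
            self.dual_level = fwAD._current_level

        def guards_state(self) -> dict:
            # Export the snapshot; never re-read the live global here.
            return {"dual_level": self.dual_level}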