[dynamo] Guard serialization for DUPLICATE_INPUT. (#152687)

Seems this guard is not very active. Adding a test to detect error handling at least.

Differential Revision: [D74074837](https://our.internmc.facebook.com/intern/diff/D74074837/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152687
Approved by: https://github.com/jansel
ghstack dependencies: #152615, #152616
This commit is contained in:
zhxchen17 2025-05-02 11:45:13 -07:00 committed by PyTorch MergeBot
parent ffd58293f7
commit 2cb16df6e2
2 changed files with 12 additions and 0 deletions

View File

@ -446,6 +446,16 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
with torch._functorch.eager_transforms.grad_increment_nesting():
self._test_check_fn(ref, loaded, {"x": x}, False)
def test_duplicate_input(self):
def fn(x, x_):
return x + x_
x = torch.randn(3, 2)
with self.assertRaisesRegex(
RuntimeError, "DUPLICATE_INPUT guard cannot be serialized"
):
self._test_serialization("DUPLICATE_INPUT", fn, x, x)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1838,6 +1838,8 @@ class GuardBuilder(GuardBuilderBase):
# TODO(voz): Deduplicate w/ AOTAutograd dupe input guards
def DUPLICATE_INPUT(self, guard, source_b):
if self.serialization_mode == "save":
raise RuntimeError("DUPLICATE_INPUT guard cannot be serialized yet.")
ref_a = self.arg_ref(guard)
ref_b = self.arg_ref(source_b.name())