mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] Guard serialization for DUPLICATE_INPUT. (#152687)
Seems this guard is not very active. Adding a test to detect error handling at least. Differential Revision: [D74074837](https://our.internmc.facebook.com/intern/diff/D74074837/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/152687 Approved by: https://github.com/jansel ghstack dependencies: #152615, #152616
This commit is contained in:
parent
ffd58293f7
commit
2cb16df6e2
|
|
@ -446,6 +446,16 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
with torch._functorch.eager_transforms.grad_increment_nesting():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
def test_duplicate_input(self):
|
||||
def fn(x, x_):
|
||||
return x + x_
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "DUPLICATE_INPUT guard cannot be serialized"
|
||||
):
|
||||
self._test_serialization("DUPLICATE_INPUT", fn, x, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -1838,6 +1838,8 @@ class GuardBuilder(GuardBuilderBase):
|
|||
|
||||
# TODO(voz): Deduplicate w/ AOTAutograd dupe input guards
|
||||
def DUPLICATE_INPUT(self, guard, source_b):
|
||||
if self.serialization_mode == "save":
|
||||
raise RuntimeError("DUPLICATE_INPUT guard cannot be serialized yet.")
|
||||
ref_a = self.arg_ref(guard)
|
||||
ref_b = self.arg_ref(source_b.name())
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user