From 94d95a91a253f19e22820912e71dba75361d4e5c Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Thu, 9 Nov 2023 00:23:04 +0000
Subject: [PATCH] Revert "[dynamo] graph break on setattr `requires_grad`
 (#113163)"

This reverts commit d261687d5f56ac8148fab2567cf1fa6dd5264def.

Reverted https://github.com/pytorch/pytorch/pull/113163 on behalf of
https://github.com/PaliC because the relevant tests were not running for
this PR; this is fixed after landing
https://github.com/pytorch/pytorch/pull/113297/
([comment](https://github.com/pytorch/pytorch/pull/113163#issuecomment-1802967236))
---
 test/dynamo/test_repros.py         | 27 ---------------------------
 torch/_dynamo/variables/builtin.py |  9 +++------
 2 files changed, 3 insertions(+), 33 deletions(-)

diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index d525ef54dec..5e47b5c1c26 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -3576,33 +3576,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
         self.assertTrue(same(out, compile_out))
 
-    def test_setattr_requires_grad_graph_breaks(self):
-        def fn(x):
-            z = x + 4
-            x.requires_grad = True
-            y = x * z
-            return y
-
-        for backend in ["count", "eager", "aot_eager"]:
-            if backend == "count":
-                backend = CompileCounter()
-            opt_fn = torch.compile(fn, backend=backend)
-
-            eager = torch.zeros(5)
-            compiled = eager.clone()
-
-            out_eager = fn(eager)
-            out_opt = opt_fn(compiled)
-
-            self.assertEqual(out_eager, out_opt)
-
-            out_eager.sum().backward()
-            out_opt.sum().backward()
-
-            self.assertEqual(eager, compiled)
-            if isinstance(backend, CompileCounter):
-                self.assertEqual(backend.frame_count, 2)  # graph breaks
-
     def test_inductor_no_recursionerror_on_for_loops(self):
         def forward(x):
             for _ in range(1000):
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 9892e3e1d2e..8d33360622a 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -1199,18 +1199,15 @@ class BuiltinVariable(VariableTracker):
         ):
             name = name_var.as_python_constant()
             if name == "data" and all(
-                isinstance(t, variables.TensorVariable) for t in [val, obj]
+                isinstance(t, variables.TensorVariable)
+                # and not (t.source is None or is_constant_source(t.source))
+                for t in [val, obj]
             ):
                 unimplemented(
                     ".data assignment to a tracked tensors can introduce aliasing, hence we "
                     "need to graph break to apply the aliasing (or track new aliased tensors) "
                     "to continue to trace the graph"
                 )
-            if name == "requires_grad" and isinstance(obj, variables.TensorVariable):
-                unimplemented(
-                    "mutating requires_grad can introduce a new leaf from non-leaf or vice versa in "
-                    "the middle of the graph, which aot_autograd does not currently know how to handle. "
-                )
             tx.output.side_effects.store_attr(obj, name, val)
             return val
         elif isinstance(obj, variables.UserDefinedObjectVariable):
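
For context (not part of the patch): below is a minimal standalone sketch of
the scenario the deleted test exercised, i.e. assigning `requires_grad` on an
input tensor inside a compiled function. `CompileCounter` is the real helper
from `torch._dynamo.testing`; the frame count of 2 only holds on builds where
the graph break from #113163 is in effect, so after this revert the printed
count may differ. Treat it as illustrative, not a regression test.

# sketch.py -- illustrative only; mirrors the deleted
# test_setattr_requires_grad_graph_breaks repro in standalone form.
import torch
from torch._dynamo.testing import CompileCounter


def fn(x):
    z = x + 4
    x.requires_grad = True  # turns the graph input into a new autograd leaf
    y = x * z
    return y


counter = CompileCounter()
opt_fn = torch.compile(fn, backend=counter)

eager = torch.zeros(5)
compiled = eager.clone()

out_eager = fn(eager)
out_opt = opt_fn(compiled)
assert torch.equal(out_eager, out_opt)

out_eager.sum().backward()
out_opt.sum().backward()
assert torch.equal(eager, compiled)

# With the graph break from #113163 in place, dynamo compiled two frames,
# one on each side of the requires_grad assignment; the deleted test
# asserted frame_count == 2. After this revert the count can differ.
print("compiled frame count:", counter.frame_count)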