diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index d525ef54dec..5e47b5c1c26 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -3576,33 +3576,6 @@ class ReproTests(torch._dynamo.test_case.TestCase): compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out) self.assertTrue(same(out, compile_out)) - def test_setattr_requires_grad_graph_breaks(self): - def fn(x): - z = x + 4 - x.requires_grad = True - y = x * z - return y - - for backend in ["count", "eager", "aot_eager"]: - if backend == "count": - backend = CompileCounter() - opt_fn = torch.compile(fn, backend=backend) - - eager = torch.zeros(5) - compiled = eager.clone() - - out_eager = fn(eager) - out_opt = opt_fn(compiled) - - self.assertEqual(out_eager, out_opt) - - out_eager.sum().backward() - out_opt.sum().backward() - - self.assertEqual(eager, compiled) - if isinstance(backend, CompileCounter): - self.assertEqual(backend.frame_count, 2) # graph breaks - def test_inductor_no_recursionerror_on_for_loops(self): def forward(x): for _ in range(1000): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 9892e3e1d2e..8d33360622a 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1199,18 +1199,15 @@ class BuiltinVariable(VariableTracker): ): name = name_var.as_python_constant() if name == "data" and all( - isinstance(t, variables.TensorVariable) for t in [val, obj] + isinstance(t, variables.TensorVariable) + # and not (t.source is None or is_constant_source(t.source)) + for t in [val, obj] ): unimplemented( ".data assignment to a tracked tensors can introduce aliasing, hence we " "need to graph break to apply the aliasing (or track new aliased tensors) " "to continue to trace the graph" ) - if name == "requires_grad" and isinstance(obj, variables.TensorVariable): - unimplemented( - "mutating requires_grad can introduce a new leaf from non-leaf or vice versa in " - "the middle of the graph, which aot_autograd does not currently know how to handle. " - ) tx.output.side_effects.store_attr(obj, name, val) return val elif isinstance(obj, variables.UserDefinedObjectVariable):