Revert "[dynamo] graph break on setattr requires_grad (#113163)"

This reverts commit d261687d5f.

Reverted https://github.com/pytorch/pytorch/pull/113163 on behalf of https://github.com/PaliC due to relevant tests are not running for this pr, however, this is fixed after landing https://github.com/pytorch/pytorch/pull/113297/ ([comment](https://github.com/pytorch/pytorch/pull/113163#issuecomment-1802967236))
This commit is contained in:
PyTorch MergeBot 2023-11-09 00:23:04 +00:00
parent 12c257cc00
commit 94d95a91a2
2 changed files with 3 additions and 33 deletions

View File

@ -3576,33 +3576,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out) compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
self.assertTrue(same(out, compile_out)) self.assertTrue(same(out, compile_out))
def test_setattr_requires_grad_graph_breaks(self):
def fn(x):
z = x + 4
x.requires_grad = True
y = x * z
return y
for backend in ["count", "eager", "aot_eager"]:
if backend == "count":
backend = CompileCounter()
opt_fn = torch.compile(fn, backend=backend)
eager = torch.zeros(5)
compiled = eager.clone()
out_eager = fn(eager)
out_opt = opt_fn(compiled)
self.assertEqual(out_eager, out_opt)
out_eager.sum().backward()
out_opt.sum().backward()
self.assertEqual(eager, compiled)
if isinstance(backend, CompileCounter):
self.assertEqual(backend.frame_count, 2) # graph breaks
def test_inductor_no_recursionerror_on_for_loops(self): def test_inductor_no_recursionerror_on_for_loops(self):
def forward(x): def forward(x):
for _ in range(1000): for _ in range(1000):

View File

@ -1199,18 +1199,15 @@ class BuiltinVariable(VariableTracker):
): ):
name = name_var.as_python_constant() name = name_var.as_python_constant()
if name == "data" and all( if name == "data" and all(
isinstance(t, variables.TensorVariable) for t in [val, obj] isinstance(t, variables.TensorVariable)
# and not (t.source is None or is_constant_source(t.source))
for t in [val, obj]
): ):
unimplemented( unimplemented(
".data assignment to a tracked tensors can introduce aliasing, hence we " ".data assignment to a tracked tensors can introduce aliasing, hence we "
"need to graph break to apply the aliasing (or track new aliased tensors) " "need to graph break to apply the aliasing (or track new aliased tensors) "
"to continue to trace the graph" "to continue to trace the graph"
) )
if name == "requires_grad" and isinstance(obj, variables.TensorVariable):
unimplemented(
"mutating requires_grad can introduce a new leaf from non-leaf or vice versa in "
"the middle of the graph, which aot_autograd does not currently know how to handle. "
)
tx.output.side_effects.store_attr(obj, name, val) tx.output.side_effects.store_attr(obj, name, val)
return val return val
elif isinstance(obj, variables.UserDefinedObjectVariable): elif isinstance(obj, variables.UserDefinedObjectVariable):