Revert "[dynamo] graph break on setattr requires_grad (#113163)"
This reverts commit d261687d5f.
Reverted https://github.com/pytorch/pytorch/pull/113163 on behalf of https://github.com/PaliC because the relevant tests were not running for this PR; however, this is fixed after landing https://github.com/pytorch/pytorch/pull/113297/ ([comment](https://github.com/pytorch/pytorch/pull/113163#issuecomment-1802967236))
This commit is contained in:
parent
12c257cc00
commit
94d95a91a2
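For context, the change being reverted made Dynamo graph-break when a compiled function sets requires_grad on a tensor. A minimal sketch of the behavior exercised by the test removed below, assuming CompileCounter is imported from torch._dynamo.testing (that import is not shown in this hunk):

import torch
from torch._dynamo.testing import CompileCounter

def fn(x):
    z = x + 4
    x.requires_grad = True  # attribute mutation inside the compiled region
    y = x * z
    return y

counter = CompileCounter()  # counts how many graphs Dynamo compiles
opt_fn = torch.compile(fn, backend=counter)
opt_fn(torch.zeros(5))

# The removed test asserted counter.frame_count == 2 under the reverted change,
# i.e. one graph break at the setattr; after this revert the count may differ.
print(counter.frame_count)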
@@ -3576,33 +3576,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
         self.assertTrue(same(out, compile_out))
 
-    def test_setattr_requires_grad_graph_breaks(self):
-        def fn(x):
-            z = x + 4
-            x.requires_grad = True
-            y = x * z
-            return y
-
-        for backend in ["count", "eager", "aot_eager"]:
-            if backend == "count":
-                backend = CompileCounter()
-            opt_fn = torch.compile(fn, backend=backend)
-
-            eager = torch.zeros(5)
-            compiled = eager.clone()
-
-            out_eager = fn(eager)
-            out_opt = opt_fn(compiled)
-
-            self.assertEqual(out_eager, out_opt)
-
-            out_eager.sum().backward()
-            out_opt.sum().backward()
-
-            self.assertEqual(eager, compiled)
-            if isinstance(backend, CompileCounter):
-                self.assertEqual(backend.frame_count, 2)  # graph breaks
-
     def test_inductor_no_recursionerror_on_for_loops(self):
         def forward(x):
             for _ in range(1000):
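The removed test compared eager execution against the compiled path. In eager mode, flipping requires_grad mid-function turns the tensor into a fresh autograd leaf, so only operations after the flip are recorded. A small eager-only illustration of that semantics (not part of this commit):

import torch

x = torch.zeros(5)
z = x + 4                 # computed before the flag flip: no grad history
x.requires_grad = True    # x becomes an autograd leaf from this point on
y = x * z
y.sum().backward()

print(x.grad)             # tensor([4., 4., 4., 4., 4.])  (dy/dx = z)
print(z.requires_grad)    # False: z was produced while x did not require grad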
@@ -1199,18 +1199,15 @@ class BuiltinVariable(VariableTracker):
         ):
             name = name_var.as_python_constant()
             if name == "data" and all(
-                isinstance(t, variables.TensorVariable) for t in [val, obj]
+                isinstance(t, variables.TensorVariable)
+                # and not (t.source is None or is_constant_source(t.source))
+                for t in [val, obj]
             ):
                 unimplemented(
                     ".data assignment to a tracked tensors can introduce aliasing, hence we "
                     "need to graph break to apply the aliasing (or track new aliased tensors) "
                     "to continue to trace the graph"
                 )
-            if name == "requires_grad" and isinstance(obj, variables.TensorVariable):
-                unimplemented(
-                    "mutating requires_grad can introduce a new leaf from non-leaf or vice versa in "
-                    "the middle of the graph, which aot_autograd does not currently know how to handle. "
-                )
             tx.output.side_effects.store_attr(obj, name, val)
             return val
         elif isinstance(obj, variables.UserDefinedObjectVariable):
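The surrounding .data guard, which this revert leaves in place, graph-breaks because rebinding .data makes two tracked tensors alias the same storage, something the traced graph would otherwise miss. A rough eager-mode illustration of that aliasing (plain PyTorch, no compilation involved):

import torch

x = torch.zeros(3)
y = torch.ones(3)
x.data = y      # x is rebound to y's storage
y.add_(1.0)     # an in-place update of y ...
print(x)        # ... is visible through x: tensor([2., 2., 2.])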