mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Revert "[dynamo] graph break on setattr requires_grad (#113163)"
This reverts commit d261687d5f.
Reverted https://github.com/pytorch/pytorch/pull/113163 on behalf of https://github.com/PaliC due to relevant tests are not running for this pr, however, this is fixed after landing https://github.com/pytorch/pytorch/pull/113297/ ([comment](https://github.com/pytorch/pytorch/pull/113163#issuecomment-1802967236))
This commit is contained in:
parent
12c257cc00
commit
94d95a91a2
|
|
@ -3576,33 +3576,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
|
||||
self.assertTrue(same(out, compile_out))
|
||||
|
||||
def test_setattr_requires_grad_graph_breaks(self):
|
||||
def fn(x):
|
||||
z = x + 4
|
||||
x.requires_grad = True
|
||||
y = x * z
|
||||
return y
|
||||
|
||||
for backend in ["count", "eager", "aot_eager"]:
|
||||
if backend == "count":
|
||||
backend = CompileCounter()
|
||||
opt_fn = torch.compile(fn, backend=backend)
|
||||
|
||||
eager = torch.zeros(5)
|
||||
compiled = eager.clone()
|
||||
|
||||
out_eager = fn(eager)
|
||||
out_opt = opt_fn(compiled)
|
||||
|
||||
self.assertEqual(out_eager, out_opt)
|
||||
|
||||
out_eager.sum().backward()
|
||||
out_opt.sum().backward()
|
||||
|
||||
self.assertEqual(eager, compiled)
|
||||
if isinstance(backend, CompileCounter):
|
||||
self.assertEqual(backend.frame_count, 2) # graph breaks
|
||||
|
||||
def test_inductor_no_recursionerror_on_for_loops(self):
|
||||
def forward(x):
|
||||
for _ in range(1000):
|
||||
|
|
|
|||
|
|
@ -1199,18 +1199,15 @@ class BuiltinVariable(VariableTracker):
|
|||
):
|
||||
name = name_var.as_python_constant()
|
||||
if name == "data" and all(
|
||||
isinstance(t, variables.TensorVariable) for t in [val, obj]
|
||||
isinstance(t, variables.TensorVariable)
|
||||
# and not (t.source is None or is_constant_source(t.source))
|
||||
for t in [val, obj]
|
||||
):
|
||||
unimplemented(
|
||||
".data assignment to a tracked tensors can introduce aliasing, hence we "
|
||||
"need to graph break to apply the aliasing (or track new aliased tensors) "
|
||||
"to continue to trace the graph"
|
||||
)
|
||||
if name == "requires_grad" and isinstance(obj, variables.TensorVariable):
|
||||
unimplemented(
|
||||
"mutating requires_grad can introduce a new leaf from non-leaf or vice versa in "
|
||||
"the middle of the graph, which aot_autograd does not currently know how to handle. "
|
||||
)
|
||||
tx.output.side_effects.store_attr(obj, name, val)
|
||||
return val
|
||||
elif isinstance(obj, variables.UserDefinedObjectVariable):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user