[Dynamo] Fix bug: GradMode doesn't carry grad state correctly after graph break (#88537)

Fixes https://github.com/pytorch/torchdynamo/issues/1446
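
In short: when Dynamo hit a graph break (triggered here by a print call) inside a torch.no_grad() block, the resumed portion of the function could run with gradient tracking re-enabled, so tensors produced after the break incorrectly reported requires_grad=True. The new test below reproduces the issue; the one-line fix in GradModeVariable.fn_name follows it.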

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88537
Approved by: https://github.com/jansel
Yanbo Liang 2022-11-07 18:03:31 +00:00 committed by PyTorch MergeBot
parent 6663ae5537
commit bd1ffc6501
2 changed files with 16 additions and 1 deletion


@@ -1271,6 +1271,21 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         res = opt_fn(x)
         self.assertTrue(same(ref, res))
 
+    # https://github.com/pytorch/torchdynamo/issues/1446
+    def test_grad_mode_carrying_correct_state_after_graph_break(self):
+        def fn(x):
+            with torch.no_grad():
+                y = x * 3
+                print("Break")
+                z = x + 2
+            return y, z
+
+        x = torch.randn(3, requires_grad=True)
+        opt_fn = torch._dynamo.optimize("eager")(fn)
+        y, z = opt_fn(x)
+        self.assertFalse(y.requires_grad)
+        self.assertFalse(z.requires_grad)
+
     def test_abc_setattr(self):
         # tests that we correctly bail out of __setattr__ calls

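For reference, this is the eager-mode behavior the test asserts (plain PyTorch, no Dynamo): everything computed under torch.no_grad() is detached from autograd even when the input requires grad, and the compiled function must preserve this across the graph break. A minimal sketch:

import torch

x = torch.randn(3, requires_grad=True)
with torch.no_grad():
    y = x * 3  # lands before the graph break in the compiled run
    z = x + 2  # lands after the graph break in the compiled run
print(y.requires_grad, z.requires_grad)  # prints: False False
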

@@ -283,7 +283,7 @@ class GradModeVariable(ContextWrappingVariable):
         return "_C._set_grad_enabled"
 
     def fn_name(self):
-        if self.target_values:
+        if self.target_values[0]:
             return "enable_grad"
         else:
             return "no_grad"