[Dynamo] Fix bug: GradMode doesn't carry grad state correctly after graph break (#88537)
Fixes https://github.com/pytorch/torchdynamo/issues/1446
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88537
Approved by: https://github.com/jansel
commit bd1ffc6501
parent 6663ae5537
@@ -1271,6 +1271,21 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         res = opt_fn(x)
         self.assertTrue(same(ref, res))
 
+    # https://github.com/pytorch/torchdynamo/issues/1446
+    def test_grad_mode_carrying_correct_state_after_graph_break(self):
+        def fn(x):
+            with torch.no_grad():
+                y = x * 3
+                print("Break")
+                z = x + 2
+            return y, z
+
+        x = torch.randn(3, requires_grad=True)
+        opt_fn = torch._dynamo.optimize("eager")(fn)
+        y, z = opt_fn(x)
+        self.assertFalse(y.requires_grad)
+        self.assertFalse(z.requires_grad)
+
     def test_abc_setattr(self):
         # tests that we correctly bail out of __setattr__ calls
 
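The new test doubles as a minimal reproduction of the reported issue. Below is a standalone sketch of the same scenario outside the unittest harness (assuming a build where torch._dynamo is importable): the print() call forces a graph break inside the no_grad() block, and both outputs must still come out with requires_grad=False.

import torch
import torch._dynamo

def fn(x):
    with torch.no_grad():
        y = x * 3
        print("Break")  # forces a graph break inside the no_grad block
        z = x + 2
    return y, z

x = torch.randn(3, requires_grad=True)
opt_fn = torch._dynamo.optimize("eager")(fn)
y, z = opt_fn(x)

# With the fix, the resumed code re-enters no_grad, so neither output
# tracks gradients.
assert not y.requires_grad
assert not z.requires_grad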
@@ -283,7 +283,7 @@ class GradModeVariable(ContextWrappingVariable):
         return "_C._set_grad_enabled"
 
     def fn_name(self):
-        if self.target_values:
+        if self.target_values[0]:
             return "enable_grad"
         else:
             return "no_grad"
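This one-line change is the actual fix. fn_name() is consulted when Dynamo rebuilds the grad-mode context for the code that resumes after a graph break, and the old check tested the truthiness of the target_values container rather than the flag stored inside it. A quick illustrative sketch of the failure mode, assuming (as an assumption about Dynamo's representation) that a no_grad() block is recorded as a (False,) tuple:

# Hypothetical standalone illustration, not Dynamo internals verbatim.
target_values = (False,)  # grad state recorded for a no_grad() block

# Old check: any non-empty container is truthy, even one holding False,
# so the resume code was told to re-enter enable_grad.
old_name = "enable_grad" if target_values else "no_grad"
print(old_name)  # -> "enable_grad"  (wrong)

# Fixed check: inspect the stored flag itself.
new_name = "enable_grad" if target_values[0] else "no_grad"
print(new_name)  # -> "no_grad"  (correct)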