# Owner(s): ["module: dynamo"]
import functools

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.compile import nop
from torch._functorch.aot_autograd import aot_module_simplified


def global_hook_0(grad):
    return grad * 4


def global_hook_1(grad):
    return grad / 2


def global_hook_2(grad):
    return grad * 3


h0 = None


class HooksTests(torch._dynamo.test_case.TestCase):
    def test_tensor_only_register_hook_in_graph_lambda(self):
        def fn(x):
            x.register_hook(lambda grad: grad * 2)
            return x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 0)

    def test_tensor_register_hook_in_graph_lambda(self):
        def fn(x, y, z):
            x.register_hook(lambda grad: grad * 2)
            return x, y * y, z * z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_break_handle_lambda(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            z = z * z
            handle.remove()
            x.register_hook(lambda grad: grad * 3)
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
        self.assertEqual(cnts.frame_count, 2)

    def test_tensor_register_hook_multi_handle_return(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            h2 = handle
            z = z * z
            return x, y * y, z, handle, h2

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)
        self.assertNotEqual(h, None)
        self.assertNotEqual(h2, None)
        self.assertEqual(h2, h)

    def test_tensor_register_hook_repeated_handle_return(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            h2 = handle
            z = z * z
            return x, y * y, z, handle, handle

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)
        self.assertNotEqual(h, None)
        self.assertNotEqual(h2, None)
        self.assertEqual(h2, h)

    def test_tensor_register_hook_repeated_handle_not_local(self):
        def fn(x, y, z, mod):
            mod.handle = x.register_hook(lambda grad: grad * 2)
            z = z * z
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)

        mod = torch.nn.Module()
        mod.handle = None

        v, y, z = fn(v, torch.randn([2, 2]), torch.randn([2, 2]), mod)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)
        self.assertNotEqual(mod.handle, None)

    def test_tensor_only_register_hook_in_graph_local(self):
        def local_hook(grad):
            return grad * 2

        def fn(x):
            x.register_hook(local_hook)
            return x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 0)

    def test_tensor_only_register_hook_in_graph_local_inner(self):
        def fn(x):
            def local_hook(grad):
                return grad * 2

            z = x * x
            x.register_hook(local_hook)
            z.register_hook(local_hook)
            return x, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v[0].grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_local(self):
        def local_hook(grad):
            return grad * 2

        def fn(x, y, z):
            x.register_hook(local_hook)
            return x, y * y, z * z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_break_handle_local(self):
        def local_hook(grad):
            return grad * 2

        def local_hook2(grad):
            return grad * 3

        def fn(x, y, z):
            handle = x.register_hook(local_hook)
            z = z * z
            handle.remove()
            x.register_hook(local_hook2)
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))

    def test_tensor_register_global_hook(self):
        def fn(x):
            x.register_hook(global_hook_0)
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_multiple_hooks(self):
        def fn(x):
            x.register_hook(global_hook_0)  # * 4
            x.register_hook(global_hook_1)  # / 2
            x.register_hook(global_hook_2)  # * 3
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_multiple_hooks_handles_in_list(self):
        def fn(x):
            h0 = x.register_hook(global_hook_0)  # * 4
            h1 = x.register_hook(global_hook_1)  # / 2
            h2 = x.register_hook(global_hook_2)  # * 3
            return x, x * x, h0, h1, h2

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, r, handle_0, handle_1, handle_2 = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))

        handle_0.remove()
        handle_1.remove()
        handle_2.remove()
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        # Handles gone, grad is just applied as is
        self.assertEqual(v.grad, torch.tensor([7.0, 14.0, 21.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_global_hooks_handles_in_list(self):
        def fn(x):
            global h0
            h0 = x.register_hook(global_hook_0)  # * 4
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, r = fn(v)

        self.assertIsNotNone(h0)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))

        h0.remove()
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        # Handles gone, grad is just applied as is
        self.assertEqual(v.grad, torch.tensor([5.0, 10.0, 15.0]))

        # NYI!
        self.assertEqual(cnts.frame_count, 0)

    def test_intermediary_hooks(self):
        def simple_hook(g):
            return g * 2

        def f(x):
            y = x + 1
            y.register_hook(simple_hook)
            z = y + 1
            return z

        out = torch.randn(1, requires_grad=True)
        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts, nopython=False)(f)
        res = fn(out)
        res.backward()
        self.assertEqual(res, f(out))
        # Will be 1 when we support hooks on intermediaries
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(out.grad, torch.Tensor([2.0]))

    def test_intermediary_hooks_same_on_aot_eager(self):
        def my_hook(grad, *, k=0):
            return grad + k

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                hook1 = functools.partial(my_hook, k=3)
                hook2 = functools.partial(my_hook, k=4)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()

        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        mod_compiled = aot_module_simplified(mod, (x1,), nop)
        aot_out = mod_compiled(x1)
        aot_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
        dynamo_out = torch._dynamo.optimize("aot_eager")(mod)(x2)
        dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, aot_out)
        self.assertEqual(dynamo_out, eager_out)
        self.assertEqual(x0.grad, x1.grad)
        self.assertEqual(x0.grad, x2.grad)

    def test_intermediary_hooks_same_on_inductor(self):
        def my_hook(grad, k=0):
            return grad + k

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                hook1 = functools.partial(my_hook, k=3)
                hook2 = functools.partial(my_hook, k=4)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()

        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        mod_compiled = aot_module_simplified(mod, (x1,), nop)
        aot_out = mod_compiled(x1)
        aot_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
        dynamo_out = torch._dynamo.optimize("inductor")(mod)(x2)
        dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, aot_out)
        self.assertEqual(dynamo_out, eager_out)
        self.assertEqual(x0.grad, x1.grad)
        self.assertEqual(x0.grad, x2.grad)

    def test_no_recompile_on_hook_identity_change(self):
        def my_hook(grad, k=0):
            return grad + k

        def my_hook2(grad):
            return grad * 2

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                y.register_hook(my_hook)
                y.register_hook(my_hook)
                z = y.mul(3)
                return (z,)

        mod = MyMod()

        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        comp_mod = torch._dynamo.optimize(cnts)(mod)
        comp_out = comp_mod(x1)
        comp_out[0].backward(torch.ones(4))

        # Will be 1 when we support hooks on intermediaries
        self.assertEqual(cnts.frame_count, 2)
        my_hook = my_hook2  # noqa: F811
        self.assertEqual(x0.grad, x1.grad)

        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        comp_out = comp_mod(x1)
        # Will be 2 when we support hooks on intermediaries
        self.assertEqual(cnts.frame_count, 4)
        comp_out[0].backward(torch.ones(4))
        self.assertEqual(x0.grad, x1.grad)
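

# Entry point so the file can be run directly (e.g. `python test_hooks.py`).
# This follows the standard convention used by other torch._dynamo test
# files; it is an assumed addition, not part of the original section.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()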