pytorch/test/dynamo/test_hooks.py
Michael Voznesensky 064ae9ff33 Support register_hook on input tensors (#108903)
The strategy in this PR is pretty straightforward.

There are 2 kinds of hooks:

1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries and outputs).

Note: Because dynamo's handling of residuals makes outputs simple, they could actually be handled as if they were inputs; for the sake of this PR, though, we will refer to hooks as either hooks on inputs (sourced) or hooks on intermediaries (not sourced).
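
For concreteness, a minimal sketch of the two cases (the hook bodies and names here are placeholders):

```python
import torch


def hook_a(grad):  # placeholder hooks, for illustration only
    return grad


def hook_b(grad):
    return grad


def f(x):
    x.register_hook(hook_a)  # (1) `x` is an input: it has a source
    y = x + 1
    y.register_hook(hook_b)  # (2) `y` is an intermediary: no source
    return y + 1


f(torch.ones(2, requires_grad=True)).sum().backward()
```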

The plan:

**For tensors w/ a source:**
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that during frame reconstruction, where we produce the bytecode that stitches our PT2-modified bytecode together with the original eager code, we emit a call to `register_hook`. Registering hooks in residuals is sound because (a) it happens right after a PT2 frame region ends, and (b) we know the tensor is alive in f_locals, f_globals, or a module in the user's invoking frame, so it is guaranteed to be around for us to call `register_hook` on. As long as we guard on the identity of the lifted function, this is sound to do.
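
As a rough illustration of the effect (a hand-written sketch, not the actual generated bytecode; `_HOOK_SLOT_0` and the eager stand-in for the compiled region are made up):

```python
import torch


def user_hook(grad):
    return grad * 2


# dynamo stashes the hook in a global slot and guards on its identity
_HOOK_SLOT_0 = user_hook  # hypothetical name for that global


def stitched_frame(x):
    out = x * 3  # stand-in for the PT2-compiled region
    # the residual bytecode replays the registration on the still-alive input
    x.register_hook(_HOOK_SLOT_0)
    return out


x = torch.ones(2, requires_grad=True)
stitched_frame(x).sum().backward()
print(x.grad)  # tensor([6., 6.]): d(out)/dx == 3, doubled by the hook
```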

**For tensors w/o a source:**
Graph break for now; we will support these in a subsequent PR.
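
For reference, this is the shape of code that currently graph-breaks (it mirrors `test_intermediary_hooks` below):

```python
import torch
import torch._dynamo


def f(x):
    y = x + 1  # intermediary: no source
    y.register_hook(lambda g: g * 2)  # not yet supported -> graph break
    return y + 1


fn = torch._dynamo.optimize("eager", nopython=False)(f)
out = fn(torch.ones(1, requires_grad=True))
out.backward()  # the hook still fires correctly via the eager fallback
```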

**Handles:**

An interesting new component here is the creation of a `STORE_FAST` -> `LOAD_FAST` pair associated with the handle, the return value of `register_hook`. If the user code stored the result of `register_hook` in a handle, we need to honor that. We do so by intercepting `STORE_FAST` and recording the name of the local variable as directed by the user code; we then honor that same name in the reconstructed bytecode. If the user did not store the handle, we merely pop the produced value to preserve the stack.
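
In user-code terms (a sketch; the interception actually happens at the bytecode level):

```python
import torch


def user_fn(x):
    h = x.register_hook(lambda g: g * 2)  # the STORE_FAST of `h` is recorded
    return x * x, h  # the reconstructed bytecode reuses the name `h`


x = torch.ones(2, requires_grad=True)
y, h = user_fn(x)
h.remove()  # `h` is a real RemovableHandle, so removal behaves as in eager
```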

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108903
Approved by: https://github.com/ezyang
ghstack dependencies: #108846, #109092
2023-09-14 01:52:21 +00:00


# Owner(s): ["module: dynamo"]
import functools

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.compile import nop
from torch._functorch.aot_autograd import aot_module_simplified


def global_hook_0(grad):
    return grad * 4


def global_hook_1(grad):
    return grad / 2


def global_hook_2(grad):
    return grad * 3


h0 = None


class HooksTests(torch._dynamo.test_case.TestCase):
    def test_tensor_only_register_hook_in_graph_lambda(self):
        def fn(x):
            x.register_hook(lambda grad: grad * 2)
            return x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        # fn performs no tensor compute besides the registration, so dynamo
        # has nothing to compile
        self.assertEqual(cnts.frame_count, 0)

    def test_tensor_register_hook_in_graph_lambda(self):
        def fn(x, y, z):
            x.register_hook(lambda grad: grad * 2)
            return x, y * y, z * z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_break_handle_lambda(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            z = z * z
            handle.remove()
            x.register_hook(lambda grad: grad * 3)
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
        # handle.remove() graph-breaks, splitting fn into two compiled frames
        self.assertEqual(cnts.frame_count, 2)

    def test_tensor_register_hook_multi_handle_return(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            h2 = handle
            z = z * z
            return x, y * y, z, handle, h2

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)
        self.assertNotEqual(h, None)
        self.assertNotEqual(h2, None)
        self.assertEqual(h2, h)

    def test_tensor_register_hook_repeated_handle_return(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            h2 = handle
            z = z * z
            return x, y * y, z, handle, handle

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)
        self.assertNotEqual(h, None)
        self.assertNotEqual(h2, None)
        self.assertEqual(h2, h)

    def test_tensor_register_hook_repeated_handle_not_local(self):
        def fn(x, y, z, mod):
            mod.handle = x.register_hook(lambda grad: grad * 2)
            z = z * z
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        mod = torch.nn.Module()
        mod.handle = None
        v, y, z = fn(v, torch.randn([2, 2]), torch.randn([2, 2]), mod)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)
        self.assertNotEqual(mod.handle, None)

    def test_tensor_only_register_hook_in_graph_local(self):
        def local_hook(grad):
            return grad * 2

        def fn(x):
            x.register_hook(local_hook)
            return x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        # As above: no tensor compute besides the registration, so nothing compiles
        self.assertEqual(cnts.frame_count, 0)

    def test_tensor_only_register_hook_in_graph_local_inner(self):
        def fn(x):
            def local_hook(grad):
                return grad * 2

            z = x * x
            x.register_hook(local_hook)
            z.register_hook(local_hook)
            return x, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v[0].grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_local(self):
        def local_hook(grad):
            return grad * 2

        def fn(x, y, z):
            x.register_hook(local_hook)
            return x, y * y, z * z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_break_handle_local(self):
        def local_hook(grad):
            return grad * 2

        def local_hook2(grad):
            return grad * 3

        def fn(x, y, z):
            handle = x.register_hook(local_hook)
            z = z * z
            handle.remove()
            x.register_hook(local_hook2)
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))

    def test_tensor_register_global_hook(self):
        def fn(x):
            x.register_hook(global_hook_0)
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_multiple_hooks(self):
        def fn(x):
            x.register_hook(global_hook_0)  # * 4
            x.register_hook(global_hook_1)  # / 2
            x.register_hook(global_hook_2)  # * 3
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_multiple_hooks_handles_in_list(self):
        def fn(x):
            h0 = x.register_hook(global_hook_0)  # * 4
            h1 = x.register_hook(global_hook_1)  # / 2
            h2 = x.register_hook(global_hook_2)  # * 3
            return x, x * x, h0, h1, h2

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, r, handle_0, handle_1, handle_2 = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
        handle_0.remove()
        handle_1.remove()
        handle_2.remove()
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        # Handles removed: the hooks no longer fire, so the raw grad accumulates
        self.assertEqual(v.grad, torch.tensor([7.0, 14.0, 21.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_global_hooks_handles_in_list(self):
        def fn(x):
            global h0
            h0 = x.register_hook(global_hook_0)  # * 4
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, r = fn(v)
        self.assertIsNotNone(h0)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
        h0.remove()
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        # Handle removed: the hook no longer fires, so the raw grad accumulates
        self.assertEqual(v.grad, torch.tensor([5.0, 10.0, 15.0]))
        # NYI! Storing the handle into a global is not yet supported, so
        # dynamo compiles nothing here.
        self.assertEqual(cnts.frame_count, 0)

    def test_intermediary_hooks(self):
        def simple_hook(g):
            return g * 2

        def f(x):
            y = x + 1
            y.register_hook(simple_hook)
            z = y + 1
            return z

        out = torch.randn(1, requires_grad=True)
        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts, nopython=False)(f)
        res = fn(out)
        res.backward()
        self.assertEqual(res, f(out))
        # Will be 1 when we support hooks on intermediaries
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(out.grad, torch.Tensor([2.0]))

    def test_intermediary_hooks_same_on_aot_eager(self):
        def my_hook(grad, *, k=0):
            return grad + k

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                hook1 = functools.partial(my_hook, k=3)
                hook2 = functools.partial(my_hook, k=4)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()

        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        mod_compiled = aot_module_simplified(mod, (x1,), nop)
        aot_out = mod_compiled(x1)
        aot_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
        dynamo_out = torch._dynamo.optimize("aot_eager")(mod)(x2)
        dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, aot_out)
        self.assertEqual(dynamo_out, eager_out)
        self.assertEqual(x0.grad, x1.grad)
        self.assertEqual(x0.grad, x2.grad)

    def test_intermediary_hooks_same_on_inductor(self):
        def my_hook(grad, k=0):
            return grad + k

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                hook1 = functools.partial(my_hook, k=3)
                hook2 = functools.partial(my_hook, k=4)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()

        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        mod_compiled = aot_module_simplified(mod, (x1,), nop)
        aot_out = mod_compiled(x1)
        aot_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
        dynamo_out = torch._dynamo.optimize("inductor")(mod)(x2)
        dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, aot_out)
        self.assertEqual(dynamo_out, eager_out)
        self.assertEqual(x0.grad, x1.grad)
        self.assertEqual(x0.grad, x2.grad)

    def test_no_recompile_on_hook_identity_change(self):
        def my_hook(grad, k=0):
            return grad + k

        def my_hook2(grad):
            return grad * 2

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                y.register_hook(my_hook)
                y.register_hook(my_hook)
                z = y.mul(3)
                return (z,)

        mod = MyMod()

        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        comp_mod = torch._dynamo.optimize(cnts)(mod)
        comp_out = comp_mod(x1)
        comp_out[0].backward(torch.ones(4))
        # Will be 1 when we support hooks on intermediaries
        self.assertEqual(cnts.frame_count, 2)

        my_hook = my_hook2  # noqa: F811
        self.assertEqual(x0.grad, x1.grad)

        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))
        comp_out = comp_mod(x1)
        # Will be 2 when we support hooks on intermediaries
        self.assertEqual(cnts.frame_count, 4)
        comp_out[0].backward(torch.ones(4))
        self.assertEqual(x0.grad, x1.grad)