pytorch/test/dynamo/test_hooks.py
Michael Voznesensky 02f6a8126e Support a simple subset of functions as backward hooks on intermediate tensors (#109537)
The main thrust of the initial effort here was to capture `register_hook` calls on tensors in compile regions. The first part of this was done in https://github.com/pytorch/pytorch/pull/108903, wherein we added support for `register_hook` on input tensors.

The distinction between inputs and intermediaries matters because the two cases require different implementations.

There are 2 kinds of hooks:

1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries and outputs)

Note: Because dynamo's residual handling keeps outputs alive, outputs could actually be treated as if they were inputs. For the sake of this PR, however, we will refer to hooks as either hooks on inputs (sourced) or hooks on intermediaries (not sourced).
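To make the distinction concrete, here is a purely illustrative sketch (the hook and function names are hypothetical, not from this PR):

```python
import torch

def noop_hook(grad):
    return grad

def fn(x):
    # x is an input to the compiled frame: it has a source (the caller's
    # f_locals/f_globals), so its hook can be re-registered in the residuals.
    x.register_hook(noop_hook)
    y = x * 2
    # y is an intermediary: it is created inside the compiled region and has
    # no source, so its hook has to go through the HOP path described below.
    y.register_hook(noop_hook)
    return y
```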

**The plan:**

For tensors w/ a source (the PR above):
We record registered hooks, store them as globals, and associate them with the tensor in the residuals. When dynamo creates the frame, that is, when we produce the bytecode that stitches our PT2-modified bytecode together with the original eager code, we emit the `register_hook` call. Registering hooks in the residuals is sound because (a) it happens right after a PT2 frame region ends and (b) we know the tensor is alive in f_locals, f_globals, or a module in the user's invoking frame, so it is guaranteed to still be around when we invoke `register_hook` on it. As long as we guard on the identity of the lifted function, this is sound to do.
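A minimal sketch of the sourced case, mirroring the input-hook tests later in this file (the backend choice is illustrative):

```python
import torch

def double_hook(grad):
    return grad * 2

@torch.compile(backend="eager")
def fn(x):
    # x has a source, so dynamo records the hook and emits the actual
    # register_hook call in the residual bytecode after the compiled region.
    x.register_hook(double_hook)
    return x * x

v = torch.ones(3, requires_grad=True)
fn(v).sum().backward()
# d(sum(x*x))/dx = 2*x = 2; the hook doubles it to 4.
print(v.grad)  # tensor([4., 4., 4.])
```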

For tensors w/o a source (this PR):

Ostensibly, the most correct and complete solution would be to smuggle the hooks into a runtime wrapper in aot_autograd, where everything the hooks close over is lifted to inputs as necessary and passed alongside the user-provided function. That is what would let us properly trace out and capture all the mutations within the user-defined hook at backward time.

That is too complicated for a first pass, so we limited the scope of this initial PR to a simple subset of hooks (see the sketch after this list):

- The hook must have a source (i.e. it is already known to us; not a lambda or a function defined as an intermediary)
- We must be tracing under compiled autograd
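In code, the constraints look roughly like this (a sketch of what is in and out of scope, not the PR's API surface):

```python
# In scope: the hook has a source, e.g. a function that is already known to us.
def sourced_hook(grad):
    return grad * 2

def f(x):
    y = x + 1                      # intermediary tensor, no source
    y.register_hook(sourced_hook)  # supported under this PR
    # y.register_hook(lambda g: g * 2)  # a lambda has no source: out of scope here
    return y + 1

# Also required: the backward must run under compiled autograd, i.e. inside
#     with compiled_autograd.enable(compiler_fn): ...
```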

**The flow** (an end-to-end sketch follows the steps below):

We use the higher-order operator (HOP) added in https://github.com/pytorch/pytorch/pull/109690/files, referred to below simply as the HOP.

1) We intercept `register_hook` calls and wrap the user-defined fn in the HOP
2) We write a `_register_hook_trampoline` into the graph: a local no-arg function that is invoked as a call_function in the dynamo graph
3) aot_autograd inlines through it during its trace and sees the HOP
4) The HOP preserves itself in the graph; it does not get traced into
5) During backward, compiled_autograd installs the HOP under a hook call
6) When compiled_autograd compiles its generated graph, dynamo traces the contents of the hook
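Putting the pieces together, an end-to-end sketch of the intermediary-hook flow, modeled on `test_intermediary_hooks_same_on_aot_eager` below (backend choices are illustrative):

```python
import torch
from torch._dynamo import compiled_autograd

def double_hook(grad):
    return grad * 2

def f(x):
    y = x + 1                     # intermediary: no source
    y.register_hook(double_hook)  # wrapped in the HOP, registered via the trampoline
    return y + 1

def backward_compiler(gm):
    # compiled_autograd hands us its generated backward graph; compiling it
    # with dynamo is what lets the hook contents get traced (step 6 above).
    return torch._dynamo.optimize("eager", nopython=True)(gm)

x = torch.ones(2, requires_grad=True)
with compiled_autograd.enable(backward_compiler):
    out = torch._dynamo.optimize("aot_eager", nopython=True)(f)(x)
    out.sum().backward()
# d(out)/dx is 1 per element; the hook on y doubles it.
print(x.grad)  # tensor([2., 2.])
```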

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109537
Approved by: https://github.com/ezyang
2023-10-11 01:35:37 +00:00

# Owner(s): ["module: dynamo"]
import functools
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.compile import nop
from torch._dynamo import compiled_autograd
from torch._functorch.aot_autograd import aot_module_simplified
def compiler_fn(gm):
return torch._dynamo.optimize("inductor", nopython=True, dynamic=True)(gm)
def global_hook_0(grad):
return grad * 4
def global_hook_1(grad):
return grad / 2
def global_hook_2(grad):
return grad * 3
h0 = None
class HooksTests(torch._dynamo.test_case.TestCase):
def test_tensor_only_register_hook_in_graph_lambda(self):
def fn(x):
x.register_hook(lambda grad: grad * 2)
return x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 0)
def test_tensor_register_hook_in_graph_lambda(self):
def fn(x, y, z):
x.register_hook(lambda grad: grad * 2)
return x, y * y, z * z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_in_graph_break_handle_lambda(self):
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
z = z * z
handle.remove()
x.register_hook(lambda grad: grad * 3)
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
self.assertEqual(cnts.frame_count, 2)
def test_tensor_register_hook_multi_handle_return(self):
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
h2 = handle
z = z * z
return x, y * y, z, handle, h2
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
self.assertNotEqual(h, None)
self.assertNotEqual(h2, None)
self.assertEqual(h2, h)
def test_tensor_register_hook_repeated_handle_return(self):
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
h2 = handle
z = z * z
return x, y * y, z, handle, handle
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
self.assertNotEqual(h, None)
self.assertNotEqual(h2, None)
self.assertEqual(h2, h)
def test_tensor_register_hook_repeated_handle_not_local(self):
def fn(x, y, z, mod):
mod.handle = x.register_hook(lambda grad: grad * 2)
z = z * z
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
mod = torch.nn.Module()
mod.handle = None
v, y, z = fn(v, torch.randn([2, 2]), torch.randn([2, 2]), mod)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
self.assertNotEqual(mod.handle, None)
def test_tensor_only_register_hook_in_graph_local(self):
def local_hook(grad):
return grad * 2
def fn(x):
x.register_hook(local_hook)
return x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 0)
def test_tensor_only_register_hook_in_graph_local_inner(self):
def fn(x):
def local_hook(grad):
return grad * 2
z = x * x
x.register_hook(local_hook)
z.register_hook(local_hook)
return x, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v[0].grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_in_graph_local(self):
def local_hook(grad):
return grad * 2
def fn(x, y, z):
x.register_hook(local_hook)
return x, y * y, z * z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_in_graph_break_handle_local(self):
def local_hook(grad):
return grad * 2
def local_hook2(grad):
return grad * 3
def fn(x, y, z):
handle = x.register_hook(local_hook)
z = z * z
handle.remove()
x.register_hook(local_hook2)
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
def test_tensor_register_global_hook(self):
def fn(x):
x.register_hook(global_hook_0)
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_multiple_hooks(self):
def fn(x):
x.register_hook(global_hook_0) # * 4
x.register_hook(global_hook_1) # / 2
x.register_hook(global_hook_2) # * 3
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_multiple_hooks_handles_in_list(self):
def fn(x):
h0 = x.register_hook(global_hook_0) # * 4
h1 = x.register_hook(global_hook_1) # / 2
h2 = x.register_hook(global_hook_2) # * 3
return x, x * x, h0, h1, h2
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, r, handle_0, handle_1, handle_2 = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
handle_0.remove()
handle_1.remove()
handle_2.remove()
v.backward(torch.tensor([1.0, 2.0, 3.0]))
# Handles gone, grad is just applied as is
self.assertEqual(v.grad, torch.tensor([7.0, 14.0, 21.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_global_hooks_handles_in_list(self):
def fn(x):
global h0
h0 = x.register_hook(global_hook_0) # * 4
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, r = fn(v)
self.assertIsNotNone(h0)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
h0.remove()
v.backward(torch.tensor([1.0, 2.0, 3.0]))
# Handles gone, grad is just applied as is
self.assertEqual(v.grad, torch.tensor([5.0, 10.0, 15.0]))
# NYI!
self.assertEqual(cnts.frame_count, 0)
def test_intermediary_hooks(self):
# Graph breaks because compiled_autograd is not set
def simple_hook(g):
return g * 2
def f(x):
y = x + 1
y.register_hook(simple_hook)
z = y + 1
return z
out = torch.randn(1, requires_grad=True)
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts, nopython=False)(f)
res = fn(out)
res.backward()
self.assertEqual(res, f(out))
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(out.grad, torch.Tensor([2.0]))
def test_intermediary_hooks_same_on_aot_eager(self):
def my_hook(grad, *, k=0):
return grad + k
class MyMod(torch.nn.Module):
def forward(self, x):
y = x.mul(2)
hook1 = functools.partial(my_hook, k=3)
hook2 = functools.partial(my_hook, k=4)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
mod_compiled = aot_module_simplified(mod, (x1,), nop)
aot_out = mod_compiled(x1)
aot_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("aot_eager", nopython=True)(mod)(x2)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
self.assertEqual(dynamo_out, eager_out)
self.assertEqual(x0.grad, x1.grad)
self.assertEqual(x0.grad, x2.grad)
def test_input_hooks_same(self):
backends = ["eager", "aot_eager", "inductor"]
for backend in backends:
def my_hook(grad, *, k=0):
return grad + k
hook = functools.partial(my_hook, k=3)
class MyMod(torch.nn.Module):
def forward(self, x):
x.register_hook(hook)
y = x.mul(2)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
mod_compiled = aot_module_simplified(mod, (x1,), nop)
aot_out = mod_compiled(x1)
aot_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
dynamo_out = torch._dynamo.optimize(backend, nopython=True)(mod)(x2)
with compiled_autograd.enable(compiler_fn):
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
self.assertEqual(dynamo_out, eager_out)
self.assertEqual(x0.grad, x1.grad)
self.assertEqual(x0.grad, x2.grad)
def test_intermediary_hooks_same_on_inductor(self):
def my_hook(grad, *, k=0):
return grad + k
class MyMod(torch.nn.Module):
def forward(self, x):
y = x.mul(2)
hook1 = functools.partial(my_hook, k=3)
hook2 = functools.partial(my_hook, k=4)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
mod_compiled = aot_module_simplified(mod, (x1,), nop)
aot_out = mod_compiled(x1)
aot_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
self.assertEqual(dynamo_out, eager_out)
self.assertEqual(x0.grad, x1.grad)
self.assertEqual(x0.grad, x2.grad)
def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor(self):
class SomePyClass:
count = 0
def do_stuff(self, grad):
if self.count % 2 == 0:
r = grad * grad
else:
r = grad + grad
self.count += 1
return r
def complex_state_touching_hook(grad, *, obj):
return obj.do_stuff(grad)
class MyMod(torch.nn.Module):
def forward(self, x, obj):
y = x.mul(2)
hook1 = functools.partial(complex_state_touching_hook, obj=obj)
hook2 = functools.partial(complex_state_touching_hook, obj=obj)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
obj = SomePyClass()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0, obj)
eager_out[0].backward(torch.ones(4))
# Eager 2
self.assertEqual(obj.count, 2)
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, eager_out)
# Eager 2 + compiled 2
self.assertEqual(obj.count, 4)
self.assertEqual(x0.grad, x2.grad)
def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor_with_graph_break(
self,
):
class SomePyClass:
grad_as_str = "None"
count = 0
def write_grad_as_str_and_do_stuff(self, grad):
self.grad_as_str = str(grad)
if self.count % 2 == 0:
r = grad * grad
else:
r = grad + grad
print("Break!")
self.count += 1
return r
def complex_state_touching_hook(grad, *, obj):
return obj.write_grad_as_str_and_do_stuff(grad)
class MyMod(torch.nn.Module):
def forward(self, x, obj):
y = x.mul(2)
hook1 = functools.partial(complex_state_touching_hook, obj=obj)
hook2 = functools.partial(complex_state_touching_hook, obj=obj)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
obj = SomePyClass()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0, obj)
eager_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, ".*BuiltinVariable\\(str\\).*"
):
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(obj.count, 2)
def test_no_recompile_on_hook_identity_change(self):
def my_hook(grad, k=0):
return grad + k
def my_hook2(grad):
return grad * 2
class MyMod(torch.nn.Module):
def forward(self, x):
y = x.mul(2)
y.register_hook(my_hook)
y.register_hook(my_hook)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
comp_mod = torch._dynamo.optimize(cnts, nopython=True)(mod)
comp_out = comp_mod(x1)
comp_out[0].backward(torch.ones(4))
self.assertEqual(cnts.frame_count, 1)
my_hook = my_hook2 # noqa: F811
self.assertEqual(x0.grad, x1.grad)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
comp_out = comp_mod(x1)
self.assertEqual(cnts.frame_count, 2)
comp_out[0].backward(torch.ones(4))
self.assertEqual(x0.grad, x1.grad)
def test_functools_arg_vary(self):
def pre_hook(grad, *, k):
return grad * k
hook = functools.partial(pre_hook, k=1)
@torch.compile(backend="eager", fullgraph=True)
def h(x):
y = x.mul(2)
y.register_hook(hook)
return y.mul(3)
with compiled_autograd.enable(torch.compile(backend="eager", fullgraph=True)):
x = torch.randn(2, requires_grad=True)
h(x).sum().backward()
orig_grad = x.grad
x.grad = None
hook = functools.partial(pre_hook, k=2)
h(x).sum().backward()
self.assertEqual(orig_grad * 2, x.grad)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()