pytorch/test/inductor/test_distributed_patterns.py
Simon Fan d274d57037 [compiled autograd][dynamo] Make compiled graph take in boxed inputs (#122353)
### Context
In today's Dynamo, we lift every tensor encountered during tracing to be an individual graph input, even when the tensors arrive in a container.
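
For example (an illustration, not actual Dynamo output):

```python
# User code passes a single container...
def fn(inputs):
    a, b = inputs
    return a + b

# ...but the traced graph historically takes each tensor as its own
# placeholder, roughly:
#
#     def forward(a, b):
#         return a + b
#
# The container boundary is erased, and every tensor becomes a separate
# graph input.
```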

And [Dynamo generates](fdc281f258/torch/_dynamo/codegen.py (L371)) the runtime function's signature using the graph's graphargs.

This means that the generated function will have each grapharg as an argument, which is problematic if we want to free the inputs in inductor codegen. See [python function arguments are kept alive for the duration of the function call](https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670).

```python
# original code
def forward(inputs):
  a, b, c, d, e = inputs
  inputs.clear()
  out = a
  out += b
  del b  # frees memory
  out += c
  del c  # frees memory
  out += d
  del d  # frees memory
  out += e
  del e  # frees memory
  return out

# compiled code:
def forward(a, b, c, d, e):
  # b, c, d, e can't be freed before the end of the function:
  # CPython keeps references to the arguments alive in the calling
  # frame for the whole duration of the call.
  ...
```
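
This lifetime behavior is observable in plain CPython, independent of PyTorch. A small standalone demo (not from the PR) using `weakref` to watch when an object dies:

```python
import weakref


class Big:
    """Stand-in for a large tensor."""


def unboxed(a):
    ref = weakref.ref(a)
    del a  # drops the callee's reference...
    return ref() is not None  # ...but the caller's frame still holds one


def boxed(inputs):
    a = inputs.pop()  # steal the only reference out of the box
    ref = weakref.ref(a)
    del a
    return ref() is not None


print(unboxed(Big()))  # True: the argument stays alive for the whole call
print(boxed([Big()]))  # False: the object was freed mid-call
```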

This isn't a concern when compiling the forward pass, because a, b, c, d, e all come from user code and should be kept alive. But when compiling the backward pass, a, b, c, d, e may be intermediate results (i.e., activations) that we DO want to free as soon as possible to stay on par with eager peak memory.

### Solution

We have encountered similar memory problems in AOTAutograd before, where we adopted the boxed calling convention (wrapping to-be-freed objects in a list), added list clearing to inductor codegen, and took care not to hold extra references to elements of the input list. We need to do something similar here, but for inputs coming from the user program (in this case, the compiled autograd fx graph). A minimal sketch of the boxed convention follows.
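
The names below are illustrative rather than the actual AOTAutograd entry points:

```python
# Boxed calling convention, minimal sketch (illustrative names, not the
# actual AOTAutograd API): the caller puts the arguments in a list and
# gives up its own references; the callee empties the list so that it
# holds the only remaining references and can drop them one by one.
def boxed_forward(inputs: list):
    a, b, c, d, e = inputs
    inputs.clear()
    out = a
    out += b
    del b  # last reference gone; memory can be freed here
    out += c
    del c
    out += d
    del d
    out += e
    del e
    return out

# Caller side: build the box, then stop referring to the elements.
# out = boxed_forward([t1, t2, t3, t4, t5])
```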

This PR supports lists as graphargs/placeholder nodes. When tracing a list of tensors, we create a single node for it and pre-emptively initialize variable trackers for its elements before they are used in the user program. Subsequent uses of those variables hit the lookup table `input_source_to_var`.
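
Dynamo's mechanism here (variable trackers, `input_source_to_var`) is its own machinery, but the shape of the resulting graph is similar to what plain `torch.fx` produces when a traced function indexes into its argument: the list stays a single placeholder, and element accesses become `operator.getitem` nodes. A rough, runnable analogy (not actual Dynamo output):

```python
import torch
import torch.fx


def f(inputs):
    # Indexing the traced proxy records operator.getitem nodes, while
    # `inputs` itself remains a single placeholder in the graph.
    return inputs[0] + inputs[1]


gm = torch.fx.symbolic_trace(f)
print(gm.graph)  # one `inputs` placeholder, two getitem nodes, one add
```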

With the inputs as a list in the graph args, our compiled code can free inputs just like in the eager case.
```python
def forward(inputs):
  a, b, c, d, e = inputs
  inputs.clear()
  # a, b, c, d, e can be freed within the function now
```

Currently, AOT/Inductor flattens list inputs via the [flatten_graph_inputs wrapper](597f479643/torch/_inductor/compile_fx.py (L1454-L1478)), which is why this PR's CI can be green. Its runtime wrapper needs additional changes, done in the next PR. The next step is to make sure we forward the list to inductor codegen without holding extra references.
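
Schematically, such a wrapper does something like the following (a simplified sketch, not the actual `compile_fx.py` code):

```python
# Simplified sketch of a list-flattening runtime wrapper (illustrative
# only). Note the hazard the follow-up work must address: `flat` holds a
# reference to every element for the duration of the inner call, which
# defeats the boxed convention unless handled carefully.
def flatten_inputs(compiled_fn):
    def runtime_wrapper(inputs: list):
        flat = list(inputs)
        inputs.clear()
        return compiled_fn(*flat)

    return runtime_wrapper
```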

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122353
Approved by: https://github.com/jansel
ghstack dependencies: #123630, #123674
2024-04-12 10:29:09 +00:00


# Owner(s): ["oncall: pt2"]
import dataclasses
import functools

import torch
from torch import nn
from torch._dynamo import compiled_autograd
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.inductor_utils import HAS_CPU
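

# init_fake_distributed() builds a single-process stand-in for an FSDP-style
# sharded linear layer: the forward pre-hook all-gathers the full weight, the
# forward post-hook restores the shard, and the backward hooks re-materialize
# the full weight and reduce-scatter its gradient.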
def init_fake_distributed():
    # Fake distributed
    WORLD_SIZE = 2

    # TODO(jansel): fix support for this
    RESIZE = False

    @torch.no_grad
    def all_gather(t):
        return torch.cat([t] * WORLD_SIZE, 0)

    @torch.no_grad
    def reduce_scatter(t):
        return t.narrow(0, 0, t.size(0) // WORLD_SIZE)

    def fw_pre_hook(mod, inp):
        with torch.no_grad():
            mod.og_weight = mod.weight
            mod.weight = nn.Parameter(all_gather(mod.weight))

    def fw_post_hook(mod, inp, out):
        if RESIZE:
            # Drop the big weight
            mod.weight.untyped_storage().resize_(0)
        mod.empty_weight = mod.weight
        mod.weight = mod.og_weight
        del mod.og_weight

    def bw_pre_hook(mod, gO):
        if RESIZE:
            mod.empty_weight.untyped_storage().resize_(
                WORLD_SIZE * mod.weight.nelement() * mod.weight.element_size()
            )
        mod.og_weight = mod.weight
        full_weight = nn.Parameter(all_gather(mod.weight))
        with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
            mod.empty_weight
        ):
            mod.empty_weight.copy_(full_weight)
        mod.weight = mod.empty_weight
        del mod.empty_weight

    def bw_post_hook(mod, gI, gO):
        grad = mod.weight.grad
        new_grad = reduce_scatter(grad)
        # No need to re-empty the weight here, the graph has been cleared
        # This removes the last reference to the big Tensor
        mod.weight = mod.og_weight
        del mod.og_weight
        mod.weight.grad = new_grad

    torch.manual_seed(1234)
    m = nn.Linear(20, 10, bias=False)
    m.weight = nn.Parameter(reduce_scatter(m.weight))
    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    m.register_forward_pre_hook(fw_pre_hook)
    m.register_forward_hook(fw_post_hook)
    return m, torch.rand(2, 20, requires_grad=True)
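

# init_module_bw_hooks() attaches counting full-backward pre/post hooks to a
# small nn.Linear; with allow_eager=False the hooks assert that they only
# ever run under compilation.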
def init_module_bw_hooks(allow_eager):
    def bw_pre_hook(mod, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_pre.add_(1)
        return (torch.sin(gO[0] + 1.2),)

    def bw_post_hook(mod, gI, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_post.add_(1)
        return (torch.sin(gI[0] + 3.4),)

    torch.manual_seed(1234)
    m = nn.Linear(10, 10)
    m.hook_count_pre = torch.tensor(0)
    m.hook_count_post = torch.tensor(0)
    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    return m, torch.rand(2, 10, requires_grad=True)


def steps(m, inp):
    for _ in range(4):
        out = m(inp)
        out.sum().backward()
    return out


class DistributedPatternTests(TestCase):
    def test_intermediate_hook_with_closure(self):
        @dataclasses.dataclass
        class CustomObj:
            val: torch.Tensor

        def fn(x, obj):
            y = x.sin()
            closure_var = y + 1
            y.register_hook(lambda grad: grad + obj.val + closure_var)
            z = y.sin()
            return z

        opt = torch.compile(fn, fullgraph=True)

        obj1 = CustomObj(torch.tensor(88))
        obj2 = CustomObj(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)

        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()
        with compiled_autograd.enable(functools.partial(torch.compile, fullgraph=True)):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    @torch.no_grad()
    def test_storage_resize_zero(self):
        @torch.compile(fullgraph=True)
        def fn(x):
            y = torch.sin(x)
            x.untyped_storage().resize_(0)
            return torch.cos(y)

        x = torch.randn(10)
        expected = torch.cos(torch.sin(x))
        y = fn(x)
        self.assertEqual(y, expected)
        self.assertEqual(x.untyped_storage().size(), 0)

    @torch.no_grad()
    def test_storage_resize_nonzero(self):
        @torch.compile(fullgraph=True)
        def fn(x, out):
            y = torch.sin(x)
            assert out.untyped_storage().size() == 0
            out.untyped_storage().resize_(x.untyped_storage().size())
            out.copy_(y.cos())

        x = torch.randn(10)
        out = torch.randn(10)
        expected = torch.cos(torch.sin(x))
        out.untyped_storage().resize_(0)
        fn(x, out)
        self.assertEqual(out.untyped_storage().size(), x.untyped_storage().size())
        self.assertEqual(out, expected)

    @torch.no_grad()
    def test_unsafe_set_version_counter1(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(w, x):
            x = x.sin()
            v = w._version
            w.copy_(x + 1)
            torch._C._autograd._unsafe_set_version_counter(w, v)
            return w, v

        for v in (3, 0, 1):
            w1 = torch.randn(16)
            for i in range(v):
                w1.fill_(i)  # bump w1._version
            self.assertEqual(w1._version, v)

            x1 = torch.randn(16)
            w2, v2 = fn(w1, x1)
            self.assertIs(w1, w2)
            self.assertEqual(w1, x1.sin() + 1)
            self.assertEqual(v2, v)
            self.assertEqual(w1._version, v)
            self.assertEqual(cnt.frame_count, 1)

    def test_unsafe_set_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad():
                v = w._version
                w.copy_(x)
                torch._C._autograd._unsafe_set_version_counter(w, v)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

    @torch.no_grad()
    def test_unsafe_preserve_version_counter1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(w, x):
            x = x.sin()
            with torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x + 1)
            return w

        w1 = torch.randn(16).fill_(0).fill_(1)
        x1 = torch.randn(16)
        v1 = w1._version
        w2 = fn(w1, x1)
        v2 = w1._version
        self.assertIs(w1, w2)
        self.assertEqual(w1, x1.sin() + 1)
        self.assertEqual(v1, v2)

    def test_unsafe_preserve_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

    def test_module_backward_hooks_eager(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        fw_cnt = CompileCounter()
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            m2 = torch.compile(m2, backend=fw_cnt, fullgraph=True)
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

        self.assertEqual(fw_cnt.frame_count, 1)
        self.assertEqual(fw_cnt.op_count, 5)
        self.assertEqual(bw_cnt.frame_count, 2)  # grad=None and grad!=None
        self.assertEqual(bw_cnt.op_count, 48)

    def test_module_backward_hooks_aot(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(True)
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        with compiled_autograd.enable(lambda gm: gm):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_inductor(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_multi_layers(self):
        a1, inp1 = init_module_bw_hooks(True)
        b1, _ = init_module_bw_hooks(True)
        out1 = steps(torch.nn.Sequential(a1, b1), inp1)

        a2, inp2 = init_module_bw_hooks(False)
        b2, _ = init_module_bw_hooks(False)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(
                torch.compile(torch.nn.Sequential(a2, b2), fullgraph=True), inp2
            )

        self.assertEqual(a1.hook_count_pre, a2.hook_count_pre)
        self.assertEqual(a1.hook_count_post, a2.hook_count_post)
        self.assertEqual(b1.hook_count_pre, b2.hook_count_pre)
        self.assertEqual(b1.hook_count_post, b2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(a1.weight.grad, a2.weight.grad)
        self.assertEqual(a1.bias.grad, a2.bias.grad)
        self.assertEqual(b1.weight.grad, b2.weight.grad)
        self.assertEqual(b1.bias.grad, b2.bias.grad)

    # TODO(jansel): support bw hooks with graph break

    def _assert_same_grad(self, a, b):
        self.assertEqual(type(a), type(b))
        self.assertEqual(a, b)
        self.assertEqual(a.grad, b.grad)
        self.assertEqual(a.requires_grad, b.requires_grad)

    def test_nn_param_return1(self):
        def fn(x):
            p = torch.nn.Parameter(x)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return2(self):
        def fn(x):
            p = torch.nn.Parameter(x, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return3(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return4(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_fake_distributed_eager(self):
        m1, inp1 = init_fake_distributed()
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed()
        fw_cnt = CompileCounter()
        m2 = torch.compile(m2, backend=fw_cnt, fullgraph=True)
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=False)):
            for step in range(1, 5):
                out2 = m2(inp2)
                out2.sum().backward()

                # Graph break on TracableCreateParameter.backward
                # Recompile on grad==None/grad!=None
                self.assertEqual(bw_cnt.frame_count, min(step, 2) * 2)

        self.assertEqual(fw_cnt.frame_count, 1)
        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)

    def test_fake_distributed_aot_eager(self):
        m1, inp1 = init_fake_distributed()
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed()
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)

        # Recompile on grad==None/grad!=None
        self.assertEqual(bw_cnt.frame_count, 2)

    def test_fake_distributed_inductor(self):
        m1, inp1 = init_fake_distributed()
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed()
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)


if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS:
        run_tests(needs="filelock")