mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Removes reentrant support for the composable checkpoint, as non-reentrant is the recommended approach and we should use this when rolling out composable checkpoint API. Also removes the standalone implementation for non-reentrant and instead uses the generator from below diff to reuse the original implemenetation. Differential Revision: [D47451375](https://our.internmc.facebook.com/intern/diff/D47451375/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/105176 Approved by: https://github.com/awgu, https://github.com/fegin
225 lines
6.6 KiB
Python
225 lines
6.6 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
import unittest
|
|
from collections import deque, OrderedDict
|
|
from contextlib import ContextDecorator
|
|
from copy import deepcopy
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.distributed._composable import checkpoint
|
|
from torch.testing._internal.common_cuda import TEST_CUDA
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
class MemoryDelta(ContextDecorator):
|
|
def __init__(self, device: torch.device):
|
|
self.device: torch.device = device
|
|
self.active_memory_enter: int = 0
|
|
self.active_memory_exit: int = 0
|
|
|
|
def __enter__(self):
|
|
self.active_memory_enter = (
|
|
torch.cuda.memory_stats()["active_bytes.all.current"]
|
|
if self.device.type == "cuda"
|
|
else 0
|
|
)
|
|
return self
|
|
|
|
def __exit__(self, *exc):
|
|
self.active_memory_exit = (
|
|
torch.cuda.memory_stats()["active_bytes.all.current"]
|
|
if self.device.type == "cuda"
|
|
else 0
|
|
)
|
|
|
|
def delta(self) -> int:
|
|
return self.active_memory_exit - self.active_memory_enter
|
|
|
|
|
|
class ToyModel(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.l1 = nn.Linear(100, 100)
|
|
self.seq = nn.Sequential(
|
|
nn.ReLU(),
|
|
nn.Linear(100, 100),
|
|
nn.ReLU(),
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.seq(self.l1(x))
|
|
|
|
|
|
class RandomModel(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.p = nn.Parameter(torch.randn(100, 100))
|
|
|
|
def forward(self, x):
|
|
y = torch.matmul(self.p, torch.randn(100, 100, device=self.p.device))
|
|
return torch.matmul(x, y)
|
|
|
|
|
|
class MultiOutputModel(nn.Module):
|
|
def __init__(self, device: torch.device):
|
|
super().__init__()
|
|
self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
|
|
self.w2 = nn.Parameter(torch.randn((100, 100), device=device))
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
z = x @ self.w1
|
|
z = nn.functional.relu(z)
|
|
z = z @ self.w2
|
|
return z.sin(), z.cos()
|
|
|
|
|
|
class MultiInputModel(nn.Module):
|
|
def __init__(self, device: torch.device):
|
|
super().__init__()
|
|
self.w = nn.Parameter(torch.randn((100, 100), device=device))
|
|
|
|
def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
|
assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
|
|
x, y = xs
|
|
z = x + y
|
|
z = z @ self.w
|
|
return nn.functional.relu(z)
|
|
|
|
|
|
class TestCheckpoint(TestCase):
|
|
def _get_graph_size(self, out: torch.Tensor) -> int:
|
|
q = deque([out.grad_fn])
|
|
num_functions = 0
|
|
while len(q):
|
|
fn = q.pop()
|
|
num_functions += 1
|
|
for next_fn, _ in fn.next_functions:
|
|
if next_fn:
|
|
q.append(next_fn)
|
|
|
|
return num_functions
|
|
|
|
def _test_tensor_only(
|
|
self,
|
|
net: nn.Module,
|
|
x: torch.Tensor,
|
|
) -> None:
|
|
x1 = x.clone()
|
|
x2 = x.clone()
|
|
x1.requires_grad = True
|
|
x2.requires_grad = True
|
|
|
|
net1 = net
|
|
net2 = deepcopy(net)
|
|
|
|
# no checkpoint
|
|
with MemoryDelta(x.device) as mem1:
|
|
loss1 = net1(x1).sum()
|
|
graph_size1 = self._get_graph_size(loss1)
|
|
loss1.backward()
|
|
|
|
# with checkpoint
|
|
checkpoint(net2.seq)
|
|
with MemoryDelta(x.device) as mem2:
|
|
loss2 = net2(x2).sum()
|
|
loss2.backward()
|
|
|
|
if x.is_cuda:
|
|
self.assertTrue(mem2.delta() < mem1.delta())
|
|
|
|
for p1, p2 in zip(net1.parameters(), net2.parameters()):
|
|
self.assertEqual(p1.grad, p2.grad)
|
|
|
|
def test_tensor_only_cpu(self):
|
|
x = torch.randn(20, 100)
|
|
net = ToyModel()
|
|
self._test_tensor_only(net, x)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "no cuda")
|
|
def test_tensor_only_gpu(self):
|
|
x = torch.randn(20, 100, device="cuda:0")
|
|
net = ToyModel().to("cuda:0")
|
|
self._test_tensor_only(net, x)
|
|
|
|
def test_random_cpu(self):
|
|
x1 = torch.randn(20, 100, requires_grad=True)
|
|
x2 = x1.clone()
|
|
|
|
net1 = RandomModel()
|
|
net2 = deepcopy(net1)
|
|
|
|
cpu_rng_state = torch.get_rng_state()
|
|
net1(x1).sum().backward()
|
|
torch.set_rng_state(cpu_rng_state)
|
|
checkpoint(net2)(x2).sum().backward()
|
|
|
|
for p1, p2 in zip(net1.parameters(), net2.parameters()):
|
|
self.assertEqual(p1.grad, p2.grad)
|
|
|
|
def test_multi_args(self):
|
|
"""
|
|
Tests checkpoint for modules with multiple output args and hence
|
|
multiple backward function input args.
|
|
"""
|
|
device = torch.device("cpu")
|
|
net1 = nn.Sequential(
|
|
MultiOutputModel(device),
|
|
MultiInputModel(device),
|
|
MultiOutputModel(device),
|
|
MultiInputModel(device),
|
|
)
|
|
net2 = deepcopy(net1)
|
|
checkpoint(net2[0])
|
|
checkpoint(net2[2])
|
|
x1 = torch.randn(20, 100, requires_grad=True)
|
|
x2 = x1.clone()
|
|
net1(x1).sum().backward()
|
|
net2(x2).sum().backward()
|
|
for p1, p2 in zip(net1.parameters(), net2.parameters()):
|
|
self.assertEqual(p1.grad, p2.grad)
|
|
|
|
def test_clears_state_on_error_in_forward(self):
|
|
class MyModel(torch.nn.Module):
|
|
def __init__(self, raise_in_recomp):
|
|
super().__init__()
|
|
self.fwd_count = 0
|
|
self.raise_in_recomp = raise_in_recomp
|
|
self.a = torch.nn.Linear(2, 2)
|
|
|
|
def forward(self, x):
|
|
if self.raise_in_recomp and self.fwd_count == 1:
|
|
raise RuntimeError("foo")
|
|
else:
|
|
if not self.raise_in_recomp:
|
|
# raise in the first forward
|
|
raise RuntimeError("foo")
|
|
self.fwd_count += 1
|
|
return self.a(x)
|
|
|
|
m = MyModel(raise_in_recomp=True)
|
|
m_seq = torch.nn.Sequential(OrderedDict({"m": m}))
|
|
checkpoint(m_seq.m)
|
|
inp = torch.randn(1, 2)
|
|
out = m_seq(inp).sum()
|
|
# Should raise in forward recomputation
|
|
with self.assertRaisesRegex(RuntimeError, "foo"):
|
|
out.backward()
|
|
|
|
# Check that _ac_generator is cleared out
|
|
self.assertEqual(None, checkpoint.state(m)._ac_generator)
|
|
|
|
m = MyModel(raise_in_recomp=False)
|
|
checkpoint(m)
|
|
inp = torch.randn(1, 2)
|
|
# Should raise in first forward
|
|
with self.assertRaises(RuntimeError):
|
|
m(inp)
|
|
|
|
self.assertEqual(None, checkpoint.state(m)._ac_generator)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|