# Owner(s): ["oncall: distributed"]

import unittest
from collections import deque, OrderedDict
from contextlib import ContextDecorator, contextmanager, nullcontext
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.checkpoint import CheckpointError


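# Context manager / decorator that records the change in active CUDA memory
# between __enter__ and __exit__, so tests can compare activation memory with
# and without checkpointing. On CPU it records 0, so the memory comparison is
# only meaningful on CUDA.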
class MemoryDelta(ContextDecorator):
    def __init__(self, device: torch.device):
        self.device: torch.device = device
        self.active_memory_enter: int = 0
        self.active_memory_exit: int = 0

    def __enter__(self):
        self.active_memory_enter = (
            torch.cuda.memory_stats()["active_bytes.all.current"]
            if self.device.type == "cuda"
            else 0
        )
        return self

    def __exit__(self, *exc):
        self.active_memory_exit = (
            torch.cuda.memory_stats()["active_bytes.all.current"]
            if self.device.type == "cuda"
            else 0
        )

    def delta(self) -> int:
        return self.active_memory_exit - self.active_memory_enter


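# Small toy modules used as checkpointing targets in the tests below.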
class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.seq(self.l1(x))


class RandomModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.p = nn.Parameter(torch.randn(100, 100))

    def forward(self, x):
        y = torch.matmul(self.p, torch.randn(100, 100, device=self.p.device))
        return torch.matmul(x, y)


class MultiOutputModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
        self.w2 = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        z = x @ self.w1
        z = nn.functional.relu(z)
        z = z @ self.w2
        return z.sin(), z.cos()


class MultiInputModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.w = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, xs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
        x, y = xs
        z = x + y
        z = z @ self.w
        return nn.functional.relu(z)


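# Tests for the composable activation checkpointing API
# (torch.distributed._composable.checkpoint), which is applied to an nn.Module
# in place rather than by wrapping individual forward calls.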
class TestCheckpoint(TestCase):
    def _get_graph_size(self, out: torch.Tensor) -> int:
        q = deque([out.grad_fn])
        num_functions = 0
        while len(q):
            fn = q.pop()
            num_functions += 1
            for next_fn, _ in fn.next_functions:
                if next_fn:
                    q.append(next_fn)

        return num_functions

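    # Runs the same forward/backward with and without checkpointing applied to
    # net.seq, checks that the checkpointed forward grows active memory less on
    # CUDA, and verifies that the gradients match in both cases.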
    def _test_tensor_only(
        self,
        net: nn.Module,
        x: torch.Tensor,
    ) -> None:
        x1 = x.clone()
        x2 = x.clone()
        x1.requires_grad = True
        x2.requires_grad = True

        net1 = net
        net2 = deepcopy(net)

        # no checkpoint
        with MemoryDelta(x.device) as mem1:
            loss1 = net1(x1).sum()
        loss1.backward()

        # with checkpoint
        checkpoint(net2.seq)
        with MemoryDelta(x.device) as mem2:
            loss2 = net2(x2).sum()
        loss2.backward()

        if x.is_cuda:
            self.assertTrue(mem2.delta() < mem1.delta())

        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_tensor_only_cpu(self):
        x = torch.randn(20, 100)
        net = ToyModel()
        self._test_tensor_only(net, x)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_tensor_only_gpu(self):
        x = torch.randn(20, 100, device="cuda:0")
        net = ToyModel().to("cuda:0")
        self._test_tensor_only(net, x)

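    # The RNG state is reset manually so that both models draw the same random
    # numbers in their original forward; the checkpointed recomputation must
    # then replay that RNG state for the gradients to match.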
    def test_random_cpu(self):
        x1 = torch.randn(20, 100, requires_grad=True)
        x2 = x1.clone()

        net1 = RandomModel()
        net2 = deepcopy(net1)

        cpu_rng_state = torch.get_rng_state()
        net1(x1).sum().backward()
        torch.set_rng_state(cpu_rng_state)
        checkpoint(net2)(x2).sum().backward()

        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_multi_args(self):
        """
        Tests checkpoint for modules with multiple output args and hence
        multiple backward function input args.
        """
        device = torch.device("cpu")
        net1 = nn.Sequential(
            MultiOutputModel(device),
            MultiInputModel(device),
            MultiOutputModel(device),
            MultiInputModel(device),
        )
        net2 = deepcopy(net1)
        checkpoint(net2[0])
        checkpoint(net2[2])
        x1 = torch.randn(20, 100, requires_grad=True)
        x2 = x1.clone()
        net1(x1).sum().backward()
        net2(x2).sum().backward()
        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

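    # A failure in the recomputed forward (or in the very first forward) must
    # not leak checkpoint state: checkpoint.state(m)._ac_generator has to be
    # reset to None in both cases.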
    def test_clears_state_on_error_in_forward(self):
        class MyModel(torch.nn.Module):
            def __init__(self, raise_in_recomp):
                super().__init__()
                self.fwd_count = 0
                self.raise_in_recomp = raise_in_recomp
                self.a = torch.nn.Linear(2, 2)

            def forward(self, x):
                if self.raise_in_recomp and self.fwd_count == 1:
                    raise RuntimeError("foo")
                else:
                    if not self.raise_in_recomp:
                        # raise in the first forward
                        raise RuntimeError("foo")
                self.fwd_count += 1
                return self.a(x)

        m = MyModel(raise_in_recomp=True)
        m_seq = torch.nn.Sequential(OrderedDict({"m": m}))
        checkpoint(m_seq.m)
        inp = torch.randn(1, 2)
        out = m_seq(inp).sum()
        # Should raise in forward recomputation
        with self.assertRaisesRegex(RuntimeError, "foo"):
            out.backward()

        # Check that _ac_generator is cleared out
        self.assertEqual(None, checkpoint.state(m)._ac_generator)

        m = MyModel(raise_in_recomp=False)
        checkpoint(m)
        inp = torch.randn(1, 2)
        # Should raise in first forward
        with self.assertRaises(RuntimeError):
            m(inp)

        self.assertEqual(None, checkpoint.state(m)._ac_generator)

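    # Exercises the keyword arguments accepted by composable checkpoint:
    # use_reentrant=True is rejected, unknown kwargs raise ValueError,
    # context_fn wraps both the original forward and the recomputation,
    # determinism_check decides whether a divergent recomputation raises a
    # CheckpointError, and preserve_rng_state decides whether the recomputation
    # replays the RNG state of the original forward.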
    def test_checkpoint_kwargs(self):
        class MyModel(torch.nn.Module):
            def __init__(self, raise_exp: bool, change_shape_in_recomp: bool):
                super().__init__()
                self.fwd_count = 0
                self.raise_exp = raise_exp
                self.change_shape_in_recomp = change_shape_in_recomp
                self.a = torch.nn.Linear(2, 2)

            def forward(self, x):
                if self.raise_exp and self.fwd_count == 0:
                    raise RuntimeError("foo")
                if self.raise_exp and self.fwd_count == 1:
                    raise RuntimeError("bar")
                if self.change_shape_in_recomp and self.fwd_count == 1:
                    x.relu_()
                random_tensor = torch.randn(1, 2)
                x = self.a(x + random_tensor)
                self.fwd_count += 1
                return x

        m = MyModel(True, False)
        m0, m1, m2, m3 = (deepcopy(m) for _ in range(4))

        # composable checkpoint does not support use_reentrant=True
        with self.assertRaisesRegex(
            NotImplementedError,
            "use_reentrant=True is not supported in composable checkpoint. "
            "Please use torch.utils.checkpoint.checkpoint instead.",
        ):
            checkpoint(m, use_reentrant=True)

        # check giving an unsupported kwarg
        with self.assertRaisesRegex(ValueError, "Unexpected keyword arguments: foo"):
            checkpoint(m0, foo="bar")

        handled_fwd_exp = False
        handled_recomp_exp = False

        @contextmanager
        def fwd_ctx(mod: MyModel):
            try:
                mod.raise_exp = False
                yield
            finally:
                nonlocal handled_fwd_exp
                handled_fwd_exp = True
                mod.raise_exp = True

        @contextmanager
        def recomp_ctx(mod: MyModel):
            try:
                mod.raise_exp = False
                yield
            finally:
                nonlocal handled_recomp_exp
                handled_recomp_exp = True
                mod.raise_exp = True

        # Test different context functions
        x = torch.randn(1, 2, requires_grad=True)
        checkpoint(
            m1, context_fn=lambda: (partial(fwd_ctx, m1)(), partial(recomp_ctx, m1)())
        )
        m1(x.clone()).sum().backward()
        self.assertEqual((handled_fwd_exp, handled_recomp_exp), (True, True))

        checkpoint(m2, context_fn=lambda: (nullcontext(), partial(recomp_ctx, m2)()))
        with self.assertRaisesRegex(RuntimeError, "foo"):
            m2(x.clone())

        handled_fwd_exp = False  # Reset flag
        checkpoint(m3, context_fn=lambda: (partial(fwd_ctx, m3)(), nullcontext()))
        with self.assertRaisesRegex(RuntimeError, "bar"):
            m3(x.clone()).sum().backward()
        self.assertEqual(handled_fwd_exp, True)

        # Test determinism check failure
        m4 = MyModel(False, True)
        m5 = deepcopy(m4)
        # Determinism check should not throw an error,
        # but autograd should throw a RuntimeError
        checkpoint(m4, determinism_check="none")
        with self.assertRaises(RuntimeError):
            m4(x.clone()).sum().backward()

        # Determinism check should throw a CheckpointError
        checkpoint(m5, determinism_check="default")
        with self.assertRaises(CheckpointError):
            m5(x.clone()).sum().backward()

        # Test preserving random state
        m6 = MyModel(False, False)
        m7, m8 = (deepcopy(m6) for _ in range(2))
        checkpoint(m7, preserve_rng_state=False)
        checkpoint(m8, preserve_rng_state=True)

        for mi in (m6, m7, m8):
            torch.manual_seed(42)
            loss = mi(x.clone()).sum()
            torch.manual_seed(41)
            loss.backward()
        # check that m6 and m7 have at least one different grad
        self.assertNotEqual(
            [p1.grad for p1 in m6.parameters()], [p2.grad for p2 in m7.parameters()]
        )
        # check that m6 and m8 have identical grads
        for p1, p2 in zip(m6.parameters(), m8.parameters()):
            self.assertEqual(p1.grad, p2.grad)


if __name__ == "__main__":
    run_tests()