pytorch/torch/testing/_internal/common_fsdp.py
Rohan Varma c4281cc92d Prototype checkpoint_wrapper (#69955)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69955

Implements a checkpoint_wrapper function, which wraps an nn.Module with checkpointing so users won't have to call checkpoint() every time they want to checkpoint the module.

Currently, only support for reentrant-based checkpointing is added, and it is tested only with FSDP to unblock a use case.

Future work is to add support for the new checkpointing API, add more tests, and upstream this to torch.utils.checkpoint.
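
A minimal usage sketch (illustrative only; the import path and the wrapped module below are assumptions, not part of this diff):

    # Hypothetical example: the checkpoint_wrapper import path may differ for this prototype.
    import torch
    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

    layer = checkpoint_wrapper(nn.Linear(16, 16))         # wrap once instead of calling checkpoint() per forward
    out = layer(torch.randn(4, 16, requires_grad=True))   # activations are recomputed during backward
    out.sum().backward()
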
ghstack-source-id: 145811242

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D33107276

fbshipit-source-id: c4a1c68d71d65713a929994940a8750f73fbdbdb

# Owner(s): ["oncall: distributed"]
from contextlib import suppress
from enum import Enum
import os
import sys
from unittest import mock
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._fsdp import FullyShardedDataParallel, CPUOffload
from torch.distributed._fsdp.fully_sharded_data_parallel import (
TrainingState_,
)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
FILE_SCHEMA,
get_cycles_per_ms,
)
class FSDPInitMode(Enum):
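    """When (if ever) to move the model to CUDA relative to FSDP wrapping."""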
# Move model to CUDA before wrap
CUDA_BEFORE = 1
# Move model to CUDA after wrap
CUDA_AFTER = 2
# Don't move model to CUDA at all.
CUDA_NEVER = 3
# Get the full params of a model recursively. Note that if CPU offloading is
# enabled, this will also automatically move the parameters to GPU, due to the
# _rebuild_full_params call.
def get_full_params(model, recurse=True):
if recurse:
# get all params for any nested FSDP instances.
for module in model.modules():
if isinstance(module, FullyShardedDataParallel):
get_full_params(module, recurse=False)
else:
torch.cuda.synchronize()
model._rebuild_full_params()
if model.module.flat_param is not None:
model.module._unflatten_params()
def _maybe_cuda(model, move_to_cuda):
return model.cuda() if move_to_cuda else model
def _maybe_wrap_fsdp(model, wrap_fsdp, *args, **kwargs):
return (
model if not wrap_fsdp
else FullyShardedDataParallel(model, *args, **kwargs)
)
class DummyProcessGroup:
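    """A stand-in for a real process group that only reports a rank and a size."""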
def __init__(self, rank: int, size: int):
self._rank = rank
self._size = size
def rank(self) -> int:
return self._rank
def size(self) -> int:
return self._size
class TransformerWithSharedParams(nn.Module):
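    """A small nn.Transformer whose embedding and output-projection weights are
    tied, used to exercise FSDP with shared parameters."""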
def __init__(
self, group, *args, d_vocab=23, d_model=16, add_bn=True,
fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs
):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
torch.manual_seed(0) # keep everything deterministic
assert (
d_vocab >= 12
), "dim of vocab should be larger than 12, as we use torch.arange(12) as input"
self.embed_tokens = nn.Embedding(d_vocab, d_model)
self.transformer = nn.Transformer(
d_model=d_model,
num_encoder_layers=2,
num_decoder_layers=2,
dim_feedforward=8,
dropout=0.1,
)
self.output_proj = nn.Linear(d_model, d_vocab)
# share the embedding and output projection weights
self.output_proj.weight = self.embed_tokens.weight
self.register_buffer(
"vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
)
self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long)) # type: ignore[arg-type]
self.bs = 2
self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
self = _maybe_cuda(self, move_to_cuda)
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
src = torch.arange(12, device=device).view(6, self.bs) # T x B
tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs) # T x B
return (src, tgt)
def forward(self, src_ids, tgt_ids):
src = self.embed_tokens(src_ids)
src = src + self.vocab_bias + self.long_buffer.type_as(src) # type: ignore[operator]
tgt = self.embed_tokens(tgt_ids)
tgt = self.bn(tgt)
x = self.transformer(src, tgt)
return self.output_proj(x)
def get_loss(self, input, output):
_, tgt = input
return nn.functional.cross_entropy(
output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
)
def run_backward(self, loss):
loss.backward()
class NestedWrappedModule(nn.Module):
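    """An nn.Sequential model whose submodules are optionally wrapped in
    FullyShardedDataParallel, used to exercise nested FSDP wrapping."""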
def __init__(self, group, wrap_fsdp, *args, wrap_everything=False, fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
def _maybe_wrap(layer):
if wrap_fsdp:
return FullyShardedDataParallel(layer, group, *args, **kwargs)
return layer
torch.manual_seed(0) # keep everything deterministic
if wrap_everything:
self.module = nn.Sequential(
_maybe_wrap(_maybe_cuda(nn.Linear(8, 4), move_to_cuda)),
_maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
_maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
_maybe_wrap(_maybe_cuda(nn.Linear(4, 8), move_to_cuda)),
)
else:
self.module = nn.Sequential(
_maybe_cuda(nn.Linear(8, 4), move_to_cuda),
_maybe_wrap(
nn.Sequential(
_maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
_maybe_cuda(nn.Linear(16, 16), move_to_cuda),
),
),
_maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
_maybe_cuda(nn.Linear(4, 8), move_to_cuda),
)
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
return (torch.rand(4, 8, device=device),)
def forward(self, x):
return self.module(x)
def get_loss(self, input, output):
loss = output.sum()
return loss
def run_backward(self, loss):
loss.backward()
class ModuleWithDelay(nn.Module):
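    """Wraps a module and injects configurable CUDA sleeps after the loss
    computation and before the reduce-scatter in the backward pass."""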
def __init__(self, module, delay_after_loss_ms=0, delay_before_reduction_ms=0):
super().__init__()
self.delay_after_loss_ms = delay_after_loss_ms
self.delay_before_reduction_ms = delay_before_reduction_ms
self.module = module
def get_input(self, device):
return self.module.get_input(device)
def forward(self, x):
return self.module(x)
def get_loss(self, input, output):
loss = self.module.get_loss(input, output)
if self.delay_after_loss_ms > 0:
torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
return loss
def run_backward(self, loss):
orig_reduce_scatter = torch.distributed._reduce_scatter_base
def _delayed_reduce_scatter(*args, **kwargs):
if self.delay_before_reduction_ms > 0:
torch.cuda._sleep(
int(self.delay_before_reduction_ms * get_cycles_per_ms())
)
return orig_reduce_scatter(*args, **kwargs)
with mock.patch(
"torch.distributed._reduce_scatter_base", _delayed_reduce_scatter
):
self.module.run_backward(loss)
class NestedWrappedModuleWithDelay(ModuleWithDelay):
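    """A NestedWrappedModule combined with ModuleWithDelay's configurable delays."""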
def __init__(
self,
group,
wrap_fsdp,
fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
cpu_offload=None,
**kwargs
):
super().__init__(
NestedWrappedModule(
group,
wrap_fsdp,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
),
**kwargs
)
class DummyDDP(nn.Module):
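    """A no-op wrapper that mimics DDP's ``.module`` attribute without any communication."""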
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
class MixtureOfExperts(NestedWrappedModule):
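    """A NestedWrappedModule variant with per-rank "expert" parameters that are
    wrapped with a world-size-1 process group when FSDP is enabled."""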
def __init__(self, group, wrap_fsdp, *args, delay_before_free_ms=0, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, **kwargs):
super().__init__(group, wrap_fsdp)
self.group = group
self.delay_before_free_ms = delay_before_free_ms
self.wrap_fsdp = wrap_fsdp
self.move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
d_expert = 23
d_shared = 12
d_input = 8
expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)
self.num_expert_params = sum([p.numel() for p in expert.parameters()])
for p in expert.parameters():
p.expert = True # type: ignore[attr-defined]
# everything else is shared
torch.manual_seed(0)
shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)
if wrap_fsdp:
# we create a process group of size 1 for the expert params
expert_group = torch.distributed.new_group(
[group.rank()]
) # world size 1 means no shard
expert = FullyShardedDataParallel(expert, expert_group, **kwargs) # type: ignore[assignment]
shared = FullyShardedDataParallel(shared, group, **kwargs) # type: ignore[assignment]
self.module = nn.Sequential(
_maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
shared,
expert,
_maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda)
)
def forward(self, x):
if self.delay_before_free_ms > 0:
expert = self.module[2]
if isinstance(expert, FullyShardedDataParallel):
orig_free_full_params = self.module[2]._free_full_params
def _free_full_params_with_delay(*args):
torch.cuda._sleep(
int(self.delay_before_free_ms * get_cycles_per_ms())
)
return orig_free_full_params(*args)
assert hasattr(
expert, "_free_full_params"
), "expert FSDP module should has _free_full_params attribute."
with mock.patch.object(
expert, "_free_full_params", _free_full_params_with_delay
):
return self.module(x)
return self.module(x)
def run_backward(self, loss):
loss.backward()
# manually reduce gradients if not wrapped in FullyShardedDataParallel
if not self.wrap_fsdp:
with torch.no_grad():
for p in self.parameters():
if hasattr(p, "expert"):
continue # these params don't need grad reduction
p.grad.div_(self.world_size)
torch.distributed.all_reduce(p.grad, group=self.group)
class FSDPTest(MultiProcessTestCase):
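    """Base class for FSDP tests: spawns one process per rank and initializes a
    default process group before running each test."""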
def setUp(self):
super(FSDPTest, self).setUp()
self._spawn_processes()
@property
def world_size(self):
return torch.cuda.device_count() if torch.cuda.is_available() else 4
@property
def init_method(self):
return "{}{file_name}".format(FILE_SCHEMA, file_name=self.file_name)
def _check_cpu_offload(self, fsdp_model, cpu_offload):
self.assertEqual(cpu_offload, fsdp_model.cpu_offload)
@classmethod
def _run(cls, rank, test_name, file_name, pipe):
self = cls(test_name)
self.rank = rank
self.file_name = file_name
print(f"dist init r={self.rank}, world={self.world_size}")
        # Fall back to the gloo backend so that init_process_group() succeeds on
        # CPU-only machines; the actual tests will be skipped if there are not
        # enough GPUs.
backend = os.environ.get("BACKEND", None)
if backend is None:
backend = "nccl" if torch.cuda.is_available() else "gloo"
try:
dist.init_process_group(
init_method=self.init_method,
backend=backend,
world_size=int(self.world_size),
rank=self.rank,
)
except RuntimeError as e:
if "recompile" in e.args[0]:
sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
raise
if torch.cuda.is_available() and torch.cuda.device_count():
torch.cuda.set_device(self.rank % torch.cuda.device_count())
# Execute barrier prior to running test to ensure that every process
# has finished initialization and that the following test
# immediately exiting due to a skip doesn't cause flakiness.
dist.barrier()
self.run_test(test_name, pipe)
dist.barrier()
dist.destroy_process_group()
sys.exit(0)
def _train_for_several_steps(self, model, num_steps, autocast, lr=0.01, fsdp_cpu_offload=None):
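        """Run ``num_steps`` optimizer steps on ``model`` and return the final loss."""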
cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
model_device = next(model.parameters()).device
        # Use SGD with momentum instead of Adam, since Adam is scale-invariant,
        # which makes it a poor fit for these tests.
optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
for _ in range(num_steps):
optim.zero_grad()
with torch.cuda.amp.autocast(enabled=autocast):
                # Inputs are always on CUDA, regardless of CPU offloading or model.device.
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
                # Post-forward, if CPU offloading is enabled, model params should be on CPU.
if cpu_offload_params and isinstance(model, FullyShardedDataParallel):
for p in model.parameters():
# Params should always be on CPU, even if
# p._is_sharded=False
self.assertEqual(p.device, torch.device("cpu"))
loss = model.module.get_loss(input, output).to(model_device)
assert (
loss.dtype == torch.float32
), "loss data type should be float32, as the original \
parameter data type is float32."
model.module.run_backward(loss)
            # Post-backward, if CPU offloading is enabled, model params should be on CPU.
if cpu_offload_params and isinstance(model, FullyShardedDataParallel):
for p in model.parameters():
# Params should always be on CPU, even if
# p._is_sharded=False
self.assertEqual(p.device, torch.device("cpu"))
optim.step()
if isinstance(model, FullyShardedDataParallel):
model._assert_state(TrainingState_.IDLE)
return loss.detach()
def _test_identical_outputs(
self,
model_init_fn,
*args,
ref_ddp_fn=None,
num_steps=2,
fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
lr=0.01,
cpu_offload=CPUOffload(),
**kwargs
):
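        """Check that a model trained under FSDP matches the same model trained
        under DistributedDataParallel, comparing both losses and full parameters."""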
group = dist.distributed_c10d._get_default_group()
rank = group.rank()
# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrap_fsdp=False).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank
)
else:
model = ref_ddp_fn(model)
ref_loss = self._train_for_several_steps(
model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
)
ref_full_params = list(model.parameters())
# Confirm we get the same behavior using FullyShardedDataParallel.
try:
model = model_init_fn(group=group, wrap_fsdp=True, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload)
except Exception as e:
            raise ValueError(f"model_init_fn {model_init_fn} raised error {str(e)}")
cpu_offload = cpu_offload or CPUOffload() # disabled if not specified.
model = FullyShardedDataParallel(model, cpu_offload=cpu_offload)
# Call model.cuda() after init FSDP if specified.
if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
model = model.cuda()
        # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we
        # expect the FSDP code to raise an error, which we check below, in the
        # case of offloaded params.
if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
for p in model.parameters():
                # Should be on CPU regardless of whether the param is sharded.
self.assertEqual(p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}")
only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
ctx = (
self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
if only_check_err else suppress()
)
with ctx:
shard_loss = self._train_for_several_steps(
model, num_steps, autocast=False, lr=lr,
fsdp_cpu_offload=cpu_offload,
)
# We only check for errors in the case we have the following setup:
# model = FSDP(model, cpu_offload=True)
# model = model.cuda()
# so skip the rest of this logic.
if only_check_err:
return
        # If CPU offloading, the next call will move model params to GPU. Sanity-check
        # that the params are on CPU beforehand.
if cpu_offload.offload_params:
device_set = {p.device for p in model.parameters()}
self.assertEqual(
{torch.device("cpu")},
device_set,
f"Got device set {device_set}"
)
get_full_params(model)
shard_full_params = list(model.parameters())
if cpu_offload.offload_params:
shard_loss = shard_loss.cuda()
torch.testing.assert_allclose(ref_loss, shard_loss)
self.assertEqual(
ref_full_params,
shard_full_params,
exact_device=True,
msg="FullyShardedDataParallel didn't match PyTorch DDP",
)
def _get_wrapped_model(
self, group, cuda_first=False, **model_kwargs
) -> FullyShardedDataParallel:
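        """Return a TransformerWithSharedParams wrapped in FullyShardedDataParallel,
        moving it to CUDA before or after wrapping depending on ``cuda_first``."""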
if cuda_first:
model = FullyShardedDataParallel(
TransformerWithSharedParams(group, **model_kwargs).cuda(), group
)
else:
model = FullyShardedDataParallel(
TransformerWithSharedParams(group, **model_kwargs), group
).cuda()
return model