# Owner(s): ["oncall: distributed"]

import os
import sys
from contextlib import suppress
from copy import deepcopy
from enum import Enum
from math import inf
from typing import Union
from unittest import mock

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_
from torch.testing._internal.common_distributed import (
    TEST_SKIPS,
    MultiProcessTestCase,
)
from torch.distributed.fsdp.wrap import wrap
from torch.testing._internal.common_utils import FILE_SCHEMA, get_cycles_per_ms


class FSDPInitMode(Enum):
    # Move model to CUDA before wrap
    CUDA_BEFORE = 1
    # Move model to CUDA after wrap
    CUDA_AFTER = 2
    # Don't move model to CUDA at all.
    CUDA_NEVER = 3
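
# Illustration of the init modes above (comments only, not executed):
#   CUDA_BEFORE: fsdp_model = FullyShardedDataParallel(model.cuda())
#   CUDA_AFTER:  fsdp_model = FullyShardedDataParallel(model).cuda()
#   CUDA_NEVER:  fsdp_model = FullyShardedDataParallel(model)  # model stays on CPU
# See FSDPTest._test_identical_outputs below for how CUDA_AFTER is exercised.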


def _get_full_detached_param(fsdp_model: FullyShardedDataParallel):
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        params = list(p.clone().detach_() for p in fsdp_model.parameters())

    return params


def _zero_model(fsdp_model: FullyShardedDataParallel):
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        for param in fsdp_model.parameters():
            with torch.no_grad():
                param.zero_()
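
# Typical use of the helpers above in a save/load parity check (a sketch
# mirroring _train_for_several_steps below, not executed here):
#   state_dict = {k: v.clone() for k, v in model.state_dict().items()}
#   _zero_model(model)
#   model.load_state_dict(state_dict)
#   # parameters should now be restored to their saved values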


def _get_state_dict(model, cpu_offload=False, half=False):
    if not cpu_offload:
        model = model.cuda()
    if half:
        model.half()

    return model.state_dict()


def subtest_name(test_name_mapping, *args):
    return '_'.join(
        [test_name_mapping[str(s)] if s is not None else "none" for s in args]
    )
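
# Example (hypothetical mapping, for illustration only):
#   test_name_mapping = {"0.1": "lr_0_1"}
#   subtest_name(test_name_mapping, 0.1, None) == "lr_0_1_none"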


# Get the full params of a model recursively. Note that if CPU offloading is
# enabled, this also automatically moves the parameters to GPU, due to the
# _rebuild_full_params call.
def get_full_params(model, recurse=True):
    with FullyShardedDataParallel.summon_full_params(model, recurse=recurse):
        return deepcopy(list(model.parameters()))


def _maybe_cuda(model, move_to_cuda):
    return model.cuda() if move_to_cuda else model


def _maybe_wrap_fsdp(model, wrap_fsdp, *args, **kwargs):
    return (
        model if not wrap_fsdp
        else FullyShardedDataParallel(model, *args, **kwargs)
    )
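
# Sketch of how the wrapping helper is typically used (hypothetical call, for
# illustration only); the extra args/kwargs are forwarded to the FSDP
# constructor:
#   model = _maybe_wrap_fsdp(nn.Linear(2, 2).cuda(), wrap_fsdp=True,
#                            cpu_offload=CPUOffload(offload_params=False))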


class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size
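
# DummyProcessGroup only mimics the rank()/size() interface of a real process
# group and performs no communication. A sketch of standalone use (assumed
# values, for illustration only):
#   pg = DummyProcessGroup(rank=0, size=2)
#   model = TransformerWithSharedParams(pg, fsdp_init_mode=FSDPInitMode.CUDA_NEVER)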


class DeterministicModel(torch.nn.Module):
    def __init__(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
        super().__init__()
        # keep everything deterministic for model initialization
        torch.manual_seed(0)
        self.inner: Union[torch.nn.Linear, FullyShardedDataParallel] = \
            torch.nn.Linear(2, 2).cuda()
        if wrap_fsdp:
            self.inner = FullyShardedDataParallel(self.inner, cpu_offload=cpu_offload)
        self.outer = torch.nn.Linear(2, 2).cuda()

    def forward(self, x):
        y = self.inner(x)
        return self.outer(y)


class TransformerWithSharedParams(nn.Module):
    def __init__(
        self, group, *args, d_vocab=23, d_model=16, add_bn=True,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        torch.manual_seed(0)  # keep everything deterministic
        assert (
            d_vocab >= 12
        ), "dim of vocab should be at least 12, as we use torch.arange(12) as input"

        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer(
            "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
        )
        self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))  # type: ignore[arg-type]

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
        move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
        self = _maybe_cuda(self, move_to_cuda)

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        src = self.embed_tokens(src_ids)
        src = src + self.vocab_bias + self.long_buffer.type_as(src)  # type: ignore[operator]
        tgt = self.embed_tokens(tgt_ids)
        tgt = self.bn(tgt)
        x = self.transformer(src, tgt)
        return self.output_proj(x)

    def get_loss(self, input, output):
        _, tgt = input
        return nn.functional.cross_entropy(
            output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
        )

    def run_backward(self, loss):
        loss.backward()
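
# The get_input / forward / get_loss / run_backward interface above is shared
# by the test models in this file. A sketch of one training step against it
# (mirroring _train_for_several_steps below, not executed here):
#   inp = model.get_input(torch.device("cuda"))
#   out = model(*inp)
#   loss = model.get_loss(inp, out)
#   model.run_backward(loss)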


class NestedWrappedModule(nn.Module):
    def __init__(self, group, wrap_fsdp, *args, wrap_everything=False, fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FullyShardedDataParallel(layer, group, *args, **kwargs)
            return layer

        torch.manual_seed(0)  # keep everything deterministic

        if wrap_everything:
            self.module = nn.Sequential(
                _maybe_wrap(_maybe_cuda(nn.Linear(8, 4), move_to_cuda)),
                _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
                _maybe_wrap(_maybe_cuda(nn.Linear(4, 8), move_to_cuda)),
            )
        else:
            self.module = nn.Sequential(
                _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
                _maybe_wrap(
                    nn.Sequential(
                        _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                        _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                    ),
                ),
                _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
                _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
            )

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        return (torch.rand(4, 8, device=device),)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        loss = output.sum()
        return loss

    def run_backward(self, loss):
        loss.backward()


class ModuleWithDelay(nn.Module):
    def __init__(self, module, delay_after_loss_ms=0, delay_before_reduction_ms=0):
        super().__init__()
        self.delay_after_loss_ms = delay_after_loss_ms
        self.delay_before_reduction_ms = delay_before_reduction_ms
        self.module = module

    def get_input(self, device):
        return self.module.get_input(device)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        loss = self.module.get_loss(input, output)
        if self.delay_after_loss_ms > 0:
            torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
        return loss

    def run_backward(self, loss):
        orig_reduce_scatter = torch.distributed._reduce_scatter_base

        def _delayed_reduce_scatter(*args, **kwargs):
            if self.delay_before_reduction_ms > 0:
                torch.cuda._sleep(
                    int(self.delay_before_reduction_ms * get_cycles_per_ms())
                )
            return orig_reduce_scatter(*args, **kwargs)

        with mock.patch(
            "torch.distributed._reduce_scatter_base", _delayed_reduce_scatter
        ):
            self.module.run_backward(loss)
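
# ModuleWithDelay injects GPU-side sleeps to widen race windows in overlap
# tests: delay_after_loss_ms stalls after the loss computation, and
# delay_before_reduction_ms stalls each gradient reduce-scatter via the
# mock.patch above. A sketch of wrapping an existing test model (assumed
# delay value, for illustration only):
#   model = ModuleWithDelay(NestedWrappedModule(group, wrap_fsdp=True),
#                           delay_before_reduction_ms=250)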


class NestedWrappedModuleWithDelay(ModuleWithDelay):
    def __init__(
        self,
        group,
        wrap_fsdp,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        cpu_offload=None,
        backward_prefetch=None,
        sharding_strategy=None,
        mixed_precision=None,
        **kwargs
    ):
        super().__init__(
            NestedWrappedModule(
                group,
                wrap_fsdp,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            ),
            **kwargs
        )


class DummyDDP(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class MixtureOfExperts(NestedWrappedModule):
    def __init__(self, group, wrap_fsdp, *args, delay_before_free_ms=0, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, **kwargs):
        super().__init__(group, wrap_fsdp)
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp
        self.move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
        # "expert" params are different on each rank
        torch.manual_seed(42 + group.rank())
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)

        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
        for p in expert.parameters():
            p.expert = True  # type: ignore[attr-defined]

        # everything else is shared
        torch.manual_seed(0)

        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group(
                [group.rank()]
            )  # world size 1 means no shard
            expert = FullyShardedDataParallel(expert, expert_group, **kwargs)  # type: ignore[assignment]

            shared = FullyShardedDataParallel(shared, group, **kwargs)  # type: ignore[assignment]

        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
            shared,
            expert,
            _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda)
        )

    def forward(self, x):
        if self.delay_before_free_ms > 0:
            expert = self.module[2]
            if isinstance(expert, FullyShardedDataParallel):
                orig_free_full_params = self.module[2]._free_full_params

                def _free_full_params_with_delay(*args):
                    torch.cuda._sleep(
                        int(self.delay_before_free_ms * get_cycles_per_ms())
                    )
                    return orig_free_full_params(*args)

                assert hasattr(
                    expert, "_free_full_params"
                ), "expert FSDP module should have a _free_full_params attribute."
                with mock.patch.object(
                    expert, "_free_full_params", _free_full_params_with_delay
                ):
                    return self.module(x)

        return self.module(x)

    def run_backward(self, loss):
        loss.backward()

        # manually reduce gradients if not wrapped in FullyShardedDataParallel
        if not self.wrap_fsdp:
            with torch.no_grad():
                for p in self.parameters():
                    if hasattr(p, "expert"):
                        continue  # these params don't need grad reduction
                    p.grad.div_(self.world_size)
                    torch.distributed.all_reduce(p.grad, group=self.group)
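
# The manual reduction above averages the shared (non-expert) gradients across
# ranks: each rank divides by world_size and then all-reduces with SUM. For
# example, with world_size=2 and per-rank gradients g0 and g1 on the same
# shared parameter, both ranks end up with (g0 + g1) / 2, matching what DDP or
# FSDP would produce for data-parallel params.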


class FSDPTest(MultiProcessTestCase):
    def setUp(self):
        super(FSDPTest, self).setUp()
        self._spawn_processes()

    @property
    def world_size(self):
        return torch.cuda.device_count() if torch.cuda.is_available() else 4

    @property
    def init_method(self):
        return "{}{file_name}".format(FILE_SCHEMA, file_name=self.file_name)

    def _check_cpu_offload(self, fsdp_model, cpu_offload):
        self.assertEqual(cpu_offload, fsdp_model.cpu_offload)

    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
        self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe):
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name

        print(f"dist init r={self.rank}, world={self.world_size}")

        # Specify the gloo backend to make 'init_process_group()' succeed.
        # Actual tests will be skipped if there are not enough GPUs.
        backend = os.environ.get("BACKEND", None)
        if backend is None:
            backend = "nccl" if torch.cuda.is_available() else "gloo"

        try:
            dist.init_process_group(
                init_method=self.init_method,
                backend=backend,
                world_size=int(self.world_size),
                rank=self.rank,
            )
        except RuntimeError as e:
            if "recompile" in e.args[0]:
                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)

            raise

        if torch.cuda.is_available() and torch.cuda.device_count():
            torch.cuda.set_device(self.rank % torch.cuda.device_count())

        # Execute a barrier prior to running the test to ensure that every
        # process has finished initialization and that a test immediately
        # exiting due to a skip doesn't cause flakiness.
        dist.barrier()

        self.run_test(test_name, pipe)

        dist.barrier()

        dist.destroy_process_group()
        sys.exit(0)

    def _train_for_several_steps(
        self,
        model,
        num_steps,
        autocast,
        lr=0.01,
        fsdp_cpu_offload=None,
        clip_norm=0.3,
        norm_type=None,
        save_model=False,
        mixed_precision=None
    ):
        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params

        model_device = next(model.parameters()).device
        # Use SGD with momentum instead of Adam, since Adam is scale-invariant,
        # which makes it a poor optimizer for these parity tests.
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        for _ in range(num_steps):
            optim.zero_grad()
            with torch.cuda.amp.autocast(enabled=autocast):
                # Inputs are always on CUDA, regardless of CPU offloading or
                # model.device.
                input = model.module.get_input(torch.device("cuda"))
                if mixed_precision and not isinstance(model, FullyShardedDataParallel):
                    if isinstance(input, torch.Tensor):
                        input = input.half()
                    else:
                        input = tuple(x.half() for x in input)
                output = model(*input)
                # Post-forward, if CPU offloading, model params should be on CPU.
                if cpu_offload_params and isinstance(model, FullyShardedDataParallel):
                    for p in model.parameters():
                        # Params should always be on CPU, even if
                        # p._is_sharded=False
                        self.assertEqual(p.device, torch.device("cpu"))

                loss = model.module.get_loss(input, output).to(model_device)
            if not mixed_precision:
                assert (
                    loss.dtype == torch.float32
                ), "loss data type should be float32, as the original \
                    parameter data type is float32."
            else:
                # FSDP loss is fp16, DDP AMP loss is fp32
                if isinstance(model, FullyShardedDataParallel):
                    self.assertEqual(loss.dtype, mixed_precision.param_dtype)
                else:
                    self.assertEqual(loss.dtype, torch.float32)
            model.module.run_backward(loss)
            if norm_type is not None:
                if isinstance(model, FullyShardedDataParallel):
                    model.clip_grad_norm_(clip_norm, norm_type)
                    total_norm_after_clip = _collect_total_grad_norm_fsdp(
                        model, norm_type, self.rank
                    )
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
                    total_norm_after_clip = _collect_total_grad_norm_local(
                        model, norm_type
                    )
                self.assertTrue(total_norm_after_clip <= clip_norm)
            # Post-backward, if CPU offloading, model params should be on CPU.
            if cpu_offload_params and isinstance(model, FullyShardedDataParallel):
                for p in model.parameters():
                    # Params should always be on CPU, even if
                    # p._is_sharded=False
                    self.assertEqual(p.device, torch.device("cpu"))
            optim.step()

        # If save_model, simulate a save + load round trip.
        if save_model:
            state_dict = {k: v.clone() for k, v in model.state_dict().items()}
            # Zero the params; if the state_dict save/load did not work
            # properly, this would break the parity test with DDP.
            _zero_model(model)

            model.load_state_dict(state_dict)

        if isinstance(model, FullyShardedDataParallel):
            model._assert_state(TrainingState_.IDLE)
        return loss.detach()

    def _test_identical_outputs(
        self,
        model_init_fn,
        *args,
        ref_ddp_fn=None,
        num_steps=2,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        lr=0.01,
        cpu_offload=CPUOffload(),
        backward_prefetch=None,
        sharding_strategy=None,
        mixed_precision=None,
        save_model=True,
        clip_norm=0.3,
        norm_type=None,
        **kwargs
    ):
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank
            )
        else:
            model = ref_ddp_fn(model)

        # DDP training
        ref_loss = self._train_for_several_steps(
            model, num_steps, autocast=mixed_precision is not None, lr=lr,
            fsdp_cpu_offload=cpu_offload, mixed_precision=mixed_precision,
        )
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        try:
            model = model_init_fn(
                group=group,
                wrap_fsdp=True,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            )
        except Exception as e:
            raise ValueError(f"model_init_fn {model_init_fn} got error {str(e)}")

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(
            model,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
        )
        # Call model.cuda() after init FSDP if specified.
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            model = model.cuda()

        # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we
        # expect FSDP code to raise the error that we check below, in the case
        # of offloaded params.
        if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of whether the param is sharded.
                self.assertEqual(p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}")

        only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
        ctx = (
            self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
            if only_check_err else suppress()
        )
        with ctx:
            # FSDP training
            shard_loss = self._train_for_several_steps(
                model, num_steps, autocast=False, lr=lr,
                fsdp_cpu_offload=cpu_offload, save_model=save_model,
                mixed_precision=mixed_precision,
            )
        # We only check for errors in the case we have the following setup:
        #   model = FSDP(model, cpu_offload=True)
        #   model = model.cuda()
        # so skip the rest of this logic.
        if only_check_err:
            return
        # If CPU offloading, the next call will move model params to GPU.
        # Sanity check that params are on CPU before.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual(
                {torch.device("cpu")},
                device_set,
                f"Got device set {device_set}"
            )
        shard_full_params = get_full_params(model)

        if cpu_offload.offload_params:
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        # Note that we don't do a parameter check when testing mixed precision,
        # as FSDP brings the full param back to fp32 but we did model.half()
        # for DDP, so they wouldn't be equal. Further, DDP + model.half() would
        # run the optimizer in reduced precision versus FSDP's full precision.
        if not mixed_precision:
            self.assertEqual(
                ref_full_params,
                shard_full_params,
                exact_device=True,
                msg="FullyShardedDataParallel didn't match PyTorch DDP",
            )

    def _get_wrapped_model(
        self, group, cuda_first=False, config=None, **model_kwargs,
    ) -> FullyShardedDataParallel:
        if config is None:
            config = {}
        move_to_cuda = not (
            "cpu_offload" in config and config["cpu_offload"].offload_params
        )
        if cuda_first:
            transformer = TransformerWithSharedParams(group, **model_kwargs)
            if move_to_cuda:
                transformer = transformer.cuda()
            model = FullyShardedDataParallel(transformer, group, **config)
        else:
            model = FullyShardedDataParallel(
                TransformerWithSharedParams(group, **model_kwargs),
                group,
                **config,
            )
            if move_to_cuda:
                model = model.cuda()
        return model
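
    # Sketch of typical use inside a test (assumed config value, for
    # illustration only); the config dict maps directly to FSDP constructor
    # kwargs:
    #   fsdp_model = self._get_wrapped_model(
    #       group, cuda_first=False,
    #       config={"cpu_offload": CPUOffload(offload_params=True)},
    #   )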

    def _get_nonwrapped_model(
        self, group, **model_kwargs,
    ) -> torch.nn.Module:
        """Returns the non-wrapped model that is wrapped in
        :meth:`_get_wrapped_model`. The model used in these two methods should
        be kept in sync for tests that use both for parity comparisons."""
        return TransformerWithSharedParams(group, **model_kwargs).cuda()


class SkipModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.lin(x)


class NestedLinear(nn.Module):
    def __init__(self, fsdp_wrap):
        super().__init__()
        if fsdp_wrap:
            self.nested_linear = wrap(nn.Linear(10, 10, bias=False).cuda())
        else:
            self.nested_linear = nn.Linear(10, 10, bias=False).cuda()

    def forward(self, x):
        return self.nested_linear(x)


class SkipModel(nn.Module):
    def __init__(self, double_nest):
        super().__init__()
        self.linear = nn.Linear(10, 10, bias=False).cuda()
        self.linear_skip = SkipModule().cuda()
        self.nested_linear = wrap(NestedLinear(fsdp_wrap=double_nest))

    def forward(self, x):
        x = self.linear(x)
        x = self.linear_skip(x)
        x = self.nested_linear(x)
        return x


def _collect_total_grad_norm_fsdp(model, norm_type, rank):
    total_norm = _collect_total_grad_norm_local(model, norm_type)
    op = torch.distributed.ReduceOp.SUM
    if norm_type == inf:
        op = torch.distributed.ReduceOp.MAX
        norm_type = 1.0
    return_norm = torch.tensor(total_norm ** norm_type, device=rank)
    dist.all_reduce(return_norm, op=op)
    return return_norm ** (1.0 / norm_type)


def _collect_total_grad_norm_local(model, norm_type):
    if norm_type == inf:
        return max(p.grad.abs().max() for p in model.parameters())
    else:
        total_norm = 0.0
        for p in model.parameters():
            local_norm = torch.linalg.norm(p.grad, norm_type, dtype=torch.float32)
            total_norm += local_norm ** norm_type
        return total_norm ** (1.0 / norm_type)
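
# Worked example for the helpers above (assumed values, for illustration
# only): with norm_type=2 and two ranks whose local gradient 2-norms are 3.0
# and 4.0, each rank computes total_norm ** 2 (9.0 and 16.0), the SUM
# all-reduce yields 25.0 on every rank, and the final result is
# 25.0 ** (1 / 2) == 5.0, i.e. the global gradient norm.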