mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Sometimes we randomly see FSDP CI failures, such as https://github.com/pytorch/pytorch/runs/6298275361?check_suite_focus=true, that are unrelated to the diff at hand. The suspicion is that some other tests set `BACKEND`, which is a generic env var for distributed tests; if those tests run earlier in the same CI container, the variable is never unset and FSDP ends up using gloo as its backend. But gloo is not currently supported, and this override was mostly added for easy testing during early FSDP development, so remove it entirely. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76878 Approved by: https://github.com/awgu
693 lines
25 KiB
Python
693 lines
25 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
import sys
|
|
from contextlib import suppress
|
|
from copy import deepcopy
|
|
from enum import Enum
|
|
from math import inf
|
|
from typing import Union
|
|
from unittest import mock
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
|
|
from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_
|
|
from torch.testing._internal.common_distributed import (
|
|
TEST_SKIPS,
|
|
MultiProcessTestCase,
|
|
)
|
|
from torch.distributed.fsdp.wrap import wrap
|
|
from torch.testing._internal.common_utils import FILE_SCHEMA, get_cycles_per_ms
|
|
|
|
|
|
class FSDPInitMode(Enum):
    """When, relative to FSDP wrapping, a test model is moved to CUDA."""

    CUDA_BEFORE = 1  # move the model to CUDA before it is wrapped
    CUDA_AFTER = 2  # move the model to CUDA after it is wrapped
    CUDA_NEVER = 3  # do not move the model to CUDA at all
|
|
|
|
def _get_full_detached_param(fsdp_model: FullyShardedDataParallel):
    """Return detached clones of all parameters of ``fsdp_model``.

    The clones are taken while ``summon_full_params`` is active, and the
    resulting list is returned after that context has been exited.
    """
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        detached = [
            param.clone().detach_() for param in fsdp_model.parameters()
        ]
    return detached
|
|
|
|
def _zero_model(fsdp_model: FullyShardedDataParallel):
    """Zero every parameter of ``fsdp_model`` in place.

    The in-place zeroing runs with autograd disabled and while
    ``summon_full_params`` is active.
    """
    with FullyShardedDataParallel.summon_full_params(fsdp_model), torch.no_grad():
        for full_param in fsdp_model.parameters():
            full_param.zero_()
|
|
|
|
def _get_state_dict(model, cpu_offload=False, half=False):
    """Return ``model.state_dict()`` after optional device/dtype conversion.

    Args:
        model: module whose state dict is requested.
        cpu_offload: when falsy, the model is first moved with ``.cuda()``.
        half: when truthy, ``.half()`` is called on the model first.
    """
    target = model if cpu_offload else model.cuda()
    if half:
        # Called for its effect on the module; the return value is unused.
        target.half()
    return target.state_dict()
|
|
|
|
def subtest_name(test_name_mapping, *args):
    """Build an underscore-joined subtest name from ``args``.

    Each non-``None`` argument is looked up in ``test_name_mapping`` by its
    ``str()`` form; a ``None`` argument contributes the literal ``"none"``.
    """
    parts = []
    for arg in args:
        if arg is None:
            parts.append("none")
        else:
            parts.append(test_name_mapping[str(arg)])
    return "_".join(parts)
|
|
|
|
def get_full_params(model, recurse=True):
    """Return a deep copy of ``model``'s parameters taken under
    ``summon_full_params``.

    ``recurse`` is forwarded to ``summon_full_params``.

    Note (carried over from the original comment): with CPU offloading the
    parameters are also moved to GPU automatically, due to the
    ``_rebuild_full_params`` call.
    """
    with FullyShardedDataParallel.summon_full_params(model, recurse=recurse):
        full_params = list(model.parameters())
        # Copy while the context is still active.
        return deepcopy(full_params)
|
|
|
|
def _maybe_cuda(model, move_to_cuda):
    """Return ``model.cuda()`` if ``move_to_cuda`` is truthy, else ``model``."""
    if move_to_cuda:
        return model.cuda()
    return model
|
|
|
|
def _maybe_wrap_fsdp(model, wrap_fsdp, *args, **kwargs):
    """Wrap ``model`` in ``FullyShardedDataParallel`` when ``wrap_fsdp`` is truthy.

    Extra positional and keyword arguments are forwarded to the FSDP
    constructor; when ``wrap_fsdp`` is falsy ``model`` is returned unchanged.
    """
    if not wrap_fsdp:
        return model
    return FullyShardedDataParallel(model, *args, **kwargs)
|
|
|
|
class DummyProcessGroup:
    """Minimal stand-in exposing only ``rank()`` and ``size()``."""

    def __init__(self, rank: int, size: int):
        self._rank, self._size = rank, size

    def rank(self) -> int:
        """Return the rank given at construction."""
        return self._rank

    def size(self) -> int:
        """Return the size given at construction."""
        return self._size
|
|
|
|
class DeterministicModel(torch.nn.Module):
    """Two stacked ``Linear(2, 2)`` layers on CUDA with seeded initialization.

    When ``wrap_fsdp`` is truthy the inner layer is wrapped in
    ``FullyShardedDataParallel`` with the given ``cpu_offload``; the outer
    layer is never wrapped here.
    """

    def __init__(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
        super().__init__()
        # Seed before creating any layer so initialization is deterministic.
        torch.manual_seed(0)
        inner_linear = torch.nn.Linear(2, 2).cuda()
        self.inner: Union[torch.nn.Linear, FullyShardedDataParallel] = (
            FullyShardedDataParallel(inner_linear, cpu_offload=cpu_offload)
            if wrap_fsdp
            else inner_linear
        )
        self.outer = torch.nn.Linear(2, 2).cuda()

    def forward(self, x):
        """Apply the inner layer, then the outer layer."""
        return self.outer(self.inner(x))
|
|
|
|
class TransformerWithSharedParams(nn.Module):
    """Small ``nn.Transformer`` test model whose embedding and output
    projection share one weight tensor.

    Args:
        group: object providing ``rank()`` and ``size()``.
        d_vocab: vocabulary size; must be at least 12 because ``get_input``
            feeds ``torch.arange(12)`` as token ids.
        d_model: embedding / model width.
        add_bn: use ``BatchNorm1d`` on the target embedding when truthy,
            ``Identity`` otherwise.
        fsdp_init_mode: the module is moved to CUDA at the end of
            construction only for ``FSDPInitMode.CUDA_BEFORE``.

    Extra positional and keyword arguments are accepted and ignored.
    """

    def __init__(
        self, group, *args, d_vocab=23, d_model=16, add_bn=True,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs
    ):
        super().__init__()
        self.rank, self.world_size = group.rank(), group.size()
        torch.manual_seed(0)  # keep everything deterministic
        assert (
            d_vocab >= 12
        ), "dim of vocab should be larger than 12, as we use torch.arange(12) as input"

        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # Tie the output projection weight to the embedding weight.
        self.output_proj.weight = self.embed_tokens.weight

        ones_bias = self.embed_tokens.weight.new_ones((d_model,))
        self.register_buffer("vocab_bias", ones_bias)
        long_zeros = torch.zeros_like(self.vocab_bias, dtype=torch.long)  # type: ignore[arg-type]
        self.register_buffer("long_buffer", long_zeros)

        self.bs = 2
        if add_bn:
            self.bn = torch.nn.BatchNorm1d(self.bs)
        else:
            self.bn = torch.nn.Identity()

        if fsdp_init_mode == FSDPInitMode.CUDA_BEFORE:
            self.cuda()

    def get_input(self, device):
        """Return the ``(src, tgt)`` token-id tensors on ``device``."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        batch = self.bs
        src = torch.arange(12, device=device).view(6, batch)  # T x B
        tgt = torch.arange(batch * 4, device=device).view(4, batch)  # T x B
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        """Embed both sequences, run the transformer, project to the vocab."""
        src_emb = self.embed_tokens(src_ids)
        src_emb = src_emb + self.vocab_bias + self.long_buffer.type_as(src_emb)  # type: ignore[operator]
        tgt_emb = self.bn(self.embed_tokens(tgt_ids))
        hidden = self.transformer(src_emb, tgt_emb)
        return self.output_proj(hidden)

    def get_loss(self, input, output):
        """Summed cross entropy of ``output`` against the target ids in ``input``."""
        tgt = input[1]
        logits = output.view(-1, output.size(-1))
        return nn.functional.cross_entropy(logits, tgt.view(-1), reduction="sum")

    def run_backward(self, loss):
        """Backpropagate ``loss``."""
        loss.backward()

    def get_ignored_modules(self):
        """Return the list of submodules for FSDP to ignore (the transformer)."""
        return [self.transformer]
|
|
|
|
|
|
class NestedWrappedModule(nn.Module):
    """Stack of ``Linear`` layers, some of them optionally FSDP-wrapped.

    Args:
        group: process group; also passed to each FSDP wrapper.
        wrap_fsdp: when truthy, selected layers are wrapped in
            ``FullyShardedDataParallel``.
        wrap_everything: when truthy, every one of four linear layers is
            wrapped individually; otherwise a nested pattern is used.
        fsdp_init_mode: each linear layer is moved to CUDA on creation only
            for ``FSDPInitMode.CUDA_BEFORE``.

    Extra positional and keyword arguments are forwarded to every FSDP
    wrapper.
    """

    def __init__(self, group, wrap_fsdp, *args, wrap_everything=False, fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs):
        super().__init__()
        self.rank, self.world_size = group.rank(), group.size()
        move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

        def _linear(in_features, out_features):
            return _maybe_cuda(nn.Linear(in_features, out_features), move_to_cuda)

        def _maybe_wrap(layer):
            if not wrap_fsdp:
                return layer
            return FullyShardedDataParallel(layer, group, *args, **kwargs)

        torch.manual_seed(0)  # keep everything deterministic

        # Layers are created (and wrapped) strictly in the order they appear
        # in the final Sequential so that seeded initialization is unchanged.
        if wrap_everything:
            layers = [
                _maybe_wrap(_linear(n_in, n_out))
                for n_in, n_out in ((8, 4), (4, 16), (16, 4), (4, 8))
            ]
        else:
            first = _linear(8, 4)
            inner_pair = nn.Sequential(
                _maybe_wrap(_linear(4, 16)),
                _linear(16, 16),
            )
            second = _maybe_wrap(inner_pair)
            third = _maybe_wrap(_linear(16, 4))
            last = _linear(4, 8)
            layers = [first, second, third, last]
        self.module = nn.Sequential(*layers)

    def get_input(self, device):
        """Return a one-element tuple holding a random ``4 x 8`` tensor."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        sample = torch.rand(4, 8, device=device)
        return (sample,)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.module(x)

    def get_loss(self, input, output):
        """Return the sum of ``output``; ``input`` is unused."""
        return output.sum()

    def run_backward(self, loss):
        """Backpropagate ``loss``."""
        loss.backward()
|
|
|
|
|
|
class ModuleWithDelay(nn.Module):
    """Wrap a test module and optionally insert CUDA sleeps.

    Args:
        module: wrapped module; must provide ``get_input``, ``get_loss`` and
            ``run_backward``.
        delay_after_loss_ms: if positive, sleep on the GPU for this many
            milliseconds after the loss is computed.
        delay_before_reduction_ms: if positive, sleep on the GPU for this
            many milliseconds before every
            ``torch.distributed._reduce_scatter_base`` call made during
            ``run_backward``.
    """

    def __init__(self, module, delay_after_loss_ms=0, delay_before_reduction_ms=0):
        super().__init__()
        self.delay_after_loss_ms = delay_after_loss_ms
        self.delay_before_reduction_ms = delay_before_reduction_ms
        self.module = module

    @staticmethod
    def _cuda_sleep_ms(delay_ms):
        # Convert milliseconds to GPU cycles and spin for that long.
        torch.cuda._sleep(int(delay_ms * get_cycles_per_ms()))

    def get_input(self, device):
        """Delegate to the wrapped module."""
        return self.module.get_input(device)

    def forward(self, x):
        """Delegate to the wrapped module."""
        return self.module(x)

    def get_loss(self, input, output):
        """Compute the wrapped module's loss, then optionally sleep."""
        loss = self.module.get_loss(input, output)
        if self.delay_after_loss_ms > 0:
            self._cuda_sleep_ms(self.delay_after_loss_ms)
        return loss

    def run_backward(self, loss):
        """Run the wrapped backward with reduce-scatter temporarily patched."""
        # Capture the real function before it is patched out.
        real_reduce_scatter = torch.distributed._reduce_scatter_base

        def _delayed_reduce_scatter(*args, **kwargs):
            if self.delay_before_reduction_ms > 0:
                self._cuda_sleep_ms(self.delay_before_reduction_ms)
            return real_reduce_scatter(*args, **kwargs)

        patcher = mock.patch(
            "torch.distributed._reduce_scatter_base", _delayed_reduce_scatter
        )
        with patcher:
            self.module.run_backward(loss)
|
|
|
|
|
|
class NestedWrappedModuleWithDelay(ModuleWithDelay):
    """``ModuleWithDelay`` wrapped around a freshly built ``NestedWrappedModule``.

    The FSDP-related keyword arguments are forwarded to
    ``NestedWrappedModule``; any remaining ``kwargs`` go to
    ``ModuleWithDelay``.
    """

    def __init__(
        self,
        group,
        wrap_fsdp,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        cpu_offload=None,
        backward_prefetch=None,
        sharding_strategy=None,
        mixed_precision=None,
        **kwargs
    ):
        fsdp_options = dict(
            fsdp_init_mode=fsdp_init_mode,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
        )
        nested = NestedWrappedModule(group, wrap_fsdp, **fsdp_options)
        super().__init__(nested, **kwargs)
|
|
|
|
|
|
class DummyDDP(nn.Module):
    """Pass-through wrapper that stores a module and forwards calls to it."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        """Call the wrapped module with the given arguments."""
        wrapped = self.module
        return wrapped(*args, **kwargs)
|
|
|
|
|
|
class MixtureOfExperts(NestedWrappedModule):
    """Test model with one rank-specific "expert" layer and shared layers.

    Args:
        group: process group for the shared layer and for gradient
            reduction.
        wrap_fsdp: when truthy, the expert and shared layers are wrapped in
            ``FullyShardedDataParallel`` (the expert with its own
            single-rank group).
        delay_before_free_ms: if positive, a CUDA sleep is injected before
            the expert's ``_free_full_params`` during ``forward``.
        fsdp_init_mode: layers are moved to CUDA on creation only for
            ``FSDPInitMode.CUDA_BEFORE``.

    Remaining ``kwargs`` are forwarded to the FSDP wrappers built here.
    """

    def __init__(self, group, wrap_fsdp, *args, delay_before_free_ms=0, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, **kwargs):
        super().__init__(group, wrap_fsdp)
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp
        self.move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

        d_expert, d_shared, d_input = 23, 12, 8

        # "expert" params are different on each rank
        torch.manual_seed(42 + group.rank())
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)

        self.num_expert_params = sum(p.numel() for p in expert.parameters())
        for expert_param in expert.parameters():
            # Tag so run_backward can skip these during gradient reduction.
            expert_param.expert = True  # type: ignore[attr-defined]

        # everything else is shared
        torch.manual_seed(0)

        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # we create a process group of size 1 for the expert params
            # (world size 1 means no shard)
            expert_group = torch.distributed.new_group([group.rank()])
            expert = FullyShardedDataParallel(expert, expert_group, **kwargs)  # type: ignore[assignment]

            shared = FullyShardedDataParallel(shared, group, **kwargs)  # type: ignore[assignment]

        # The first and last projections are created here, in this order,
        # around the already-built shared and expert layers.
        input_proj = _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda)
        output_proj = _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda)
        self.module = nn.Sequential(input_proj, shared, expert, output_proj)

    def forward(self, x):
        """Run the stack, optionally delaying the expert's parameter free."""
        if not self.delay_before_free_ms > 0:
            return self.module(x)

        expert = self.module[2]
        if not isinstance(expert, FullyShardedDataParallel):
            return self.module(x)

        orig_free_full_params = expert._free_full_params

        def _free_full_params_with_delay(*args):
            torch.cuda._sleep(
                int(self.delay_before_free_ms * get_cycles_per_ms())
            )
            return orig_free_full_params(*args)

        assert hasattr(
            expert, "_free_full_params"
        ), "expert FSDP module should has _free_full_params attribute."
        patched_free = mock.patch.object(
            expert, "_free_full_params", _free_full_params_with_delay
        )
        with patched_free:
            return self.module(x)

    def run_backward(self, loss):
        """Backpropagate and, without FSDP, average shared grads by hand."""
        loss.backward()

        if self.wrap_fsdp:
            return

        # manually reduce gradients if not wrapped in FullyShardedDataParallel
        with torch.no_grad():
            for param in self.parameters():
                if hasattr(param, "expert"):
                    # these params don't need grad reduction
                    continue
                param.grad.div_(self.world_size)
                torch.distributed.all_reduce(param.grad, group=self.group)
|
|
|
|
|
|
class FSDPTest(MultiProcessTestCase):
    """Base class for multi-process FSDP tests.

    Spawns the worker processes in ``setUp``, initializes a process group in
    each worker (``_run``), and provides helpers that train a model for a
    few steps and compare FSDP results against a DDP reference.
    """

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self):
        """Number of CUDA devices, or 4 when CUDA is unavailable."""
        return torch.cuda.device_count() if torch.cuda.is_available() else 4

    @property
    def init_method(self):
        """File-based init method string for ``init_process_group``."""
        return f"{FILE_SCHEMA}{self.file_name}"

    def _check_cpu_offload(self, fsdp_model, cpu_offload):
        """Assert that ``fsdp_model.cpu_offload`` equals ``cpu_offload``."""
        self.assertEqual(cpu_offload, fsdp_model.cpu_offload)

    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
        """Assert that ``fsdp_model.backward_prefetch`` equals ``backward_prefetch``."""
        self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe):
        """Worker entry point: init the process group, run one test, exit."""
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name

        print(f"dist init r={self.rank}, world={self.world_size}")

        # Fall back to the gloo backend when CUDA is unavailable so that
        # 'init_process_group()' still succeeds; the actual tests will be
        # skipped if there are not enough GPUs.
        backend = "nccl" if torch.cuda.is_available() else "gloo"

        try:
            dist.init_process_group(
                init_method=self.init_method,
                backend=backend,
                world_size=int(self.world_size),
                rank=self.rank,
            )
        except RuntimeError as e:
            # A "recompile" message is reported as a skipped test
            # (backend unavailable) rather than a failure.
            if "recompile" in e.args[0]:
                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)

            raise

        if torch.cuda.is_available() and torch.cuda.device_count():
            torch.cuda.set_device(self.rank % torch.cuda.device_count())

        # Execute barrier prior to running test to ensure that every process
        # has finished initialization and that the following test
        # immediately exiting due to a skip doesn't cause flakiness.
        dist.barrier()

        self.run_test(test_name, pipe)

        dist.barrier()

        dist.destroy_process_group()
        sys.exit(0)

    def _train_for_several_steps(
        self,
        model,
        num_steps,
        autocast,
        lr=0.01,
        fsdp_cpu_offload=None,
        clip_norm=0.3,
        norm_type=None,
        save_model=False,
        mixed_precision=None
    ):
        """Train ``model`` with SGD for ``num_steps`` and return the last loss.

        Args:
            model: wrapper (FSDP, DDP, ...) whose ``.module`` provides
                ``get_input``, ``get_loss`` and ``run_backward``.
            num_steps: number of optimizer steps; the returned loss comes
                from the final step.
            autocast: whether ``torch.cuda.amp.autocast`` is enabled for the
                forward pass and loss computation.
            lr: SGD learning rate.
            fsdp_cpu_offload: CPU offload config; when its ``offload_params``
                is set and ``model`` is FSDP, parameters are asserted to be
                on CPU after forward and after backward.
            clip_norm: maximum gradient norm, used only when ``norm_type``
                is not ``None``.
            norm_type: gradient-norm type; ``None`` disables clipping.
            save_model: after each step, clone the state dict, zero the
                model, and load the state dict back.
            mixed_precision: mixed-precision config; affects input casting
                for non-FSDP models and the expected loss dtype.

        Returns:
            The detached loss of the last step.
        """
        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
        is_fsdp = isinstance(model, FullyShardedDataParallel)

        model_device = next(model.parameters()).device
        # use SGD with momentum instead of Adam, since Adam is scale invariant
        # and this makes it bad for tests
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        for _ in range(num_steps):
            optim.zero_grad()
            with torch.cuda.amp.autocast(enabled=autocast):
                # Inputs always cuda regardless of cpu offloading, or model.device
                inputs = model.module.get_input(torch.device("cuda"))
                if mixed_precision and not is_fsdp:
                    if isinstance(inputs, torch.Tensor):
                        inputs = inputs.half()
                    else:
                        inputs = tuple(x.half() for x in inputs)
                output = model(*inputs)
                # Post-forward, if CPU offloading model param should be on CPU.
                if cpu_offload_params and is_fsdp:
                    for p in model.parameters():
                        # Params should always be on CPU, even if
                        # p._is_sharded=False
                        self.assertEqual(p.device, torch.device("cpu"))

                loss = model.module.get_loss(inputs, output).to(model_device)
            if not mixed_precision:
                assert loss.dtype == torch.float32, (
                    "loss data type should be float32, as the original "
                    "parameter data type is float32."
                )
            else:
                # FSDP loss is fp16, DDP AMP loss is fp32
                if is_fsdp:
                    self.assertEqual(loss.dtype, mixed_precision.param_dtype)
                else:
                    self.assertEqual(loss.dtype, torch.float32)
            model.module.run_backward(loss)
            if norm_type is not None:
                if is_fsdp:
                    model.clip_grad_norm_(clip_norm, norm_type)
                    total_norm_after_clip = _collect_total_grad_norm_fsdp(
                        model, norm_type, self.rank
                    )
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
                    total_norm_after_clip = _collect_total_grad_norm_local(
                        model, norm_type
                    )
                self.assertTrue(total_norm_after_clip <= clip_norm)
            # Post-backward, if CPU offloading model params should be on CPU.
            if cpu_offload_params and is_fsdp:
                for p in model.parameters():
                    # Params should always be on CPU, even if
                    # p._is_sharded=False
                    self.assertEqual(p.device, torch.device("cpu"))
            optim.step()
            # if save_model, simulate save + load.
            if save_model:
                state_dict = {k: v.clone() for k, v in model.state_dict().items()}
                # Zero params, if save/load state_dict did not work properly, this
                # would break the parity test with DDP.
                _zero_model(model)

                model.load_state_dict(state_dict)

        if is_fsdp:
            model._assert_state(TrainingState_.IDLE)
        return loss.detach()

    def _test_identical_outputs(
        self,
        model_init_fn,
        *args,
        ref_ddp_fn=None,
        num_steps=2,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        lr=0.01,
        cpu_offload=CPUOffload(),
        backward_prefetch=None,
        sharding_strategy=None,
        mixed_precision=None,
        save_model=True,
        clip_norm=0.3,
        norm_type=None,
        **kwargs
    ):
        """Check that FSDP training matches a DDP reference.

        A reference model built with ``wrap_fsdp=False`` is trained under DDP
        (or under ``ref_ddp_fn(model)`` when given). A second model built
        with ``wrap_fsdp=True`` is wrapped in FSDP and trained the same way.
        The final losses are compared, and, unless mixed precision is used,
        the full parameters as well.

        ``clip_norm`` and ``norm_type`` are forwarded to both training runs,
        so gradient clipping (when ``norm_type`` is not ``None``) is applied
        to the reference and to the FSDP model alike.

        When ``fsdp_init_mode`` is ``CUDA_AFTER`` and parameter offloading is
        enabled, only the expected ``AssertionError`` is checked and the
        parity comparison is skipped.

        Extra positional ``args`` and ``kwargs`` are accepted but unused.
        """
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank
            )
        else:
            model = ref_ddp_fn(model)

        # DDP training
        ref_loss = self._train_for_several_steps(
            model, num_steps, autocast=mixed_precision is not None, lr=lr,
            fsdp_cpu_offload=cpu_offload, mixed_precision=mixed_precision,
            clip_norm=clip_norm, norm_type=norm_type,
        )
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        try:
            model = model_init_fn(
                group=group,
                wrap_fsdp=True,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            )
        except Exception as e:
            raise ValueError(f"model_Init_fn {model_init_fn} got error {str(e)}") from e

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(
            model,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
        )
        # Call model.cuda() after init FSDP if specified.
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            model = model.cuda()

        # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we
        # expect FSDP code to raise error that we check below, in the case of
        # offload params.
        if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of if param is sharded.
                self.assertEqual(p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}")

        only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
        ctx = (
            self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
            if only_check_err else suppress()
        )
        with ctx:
            # FSDP training
            shard_loss = self._train_for_several_steps(
                model, num_steps, autocast=False, lr=lr,
                fsdp_cpu_offload=cpu_offload, save_model=save_model,
                mixed_precision=mixed_precision,
                clip_norm=clip_norm, norm_type=norm_type,
            )
        # We only check for errors in the case we have the following setup:
        # model = FSDP(model, cpu_offload=True)
        # model = model.cuda()
        # so skip the rest of this logic.
        if only_check_err:
            return
        # If CPU offload, next call will change model params to GPU. Sanity
        # check that params are on CPU before.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual(
                {torch.device("cpu")},
                device_set,
                f"Got device set {device_set}"
            )
        shard_full_params = get_full_params(model)

        if cpu_offload.offload_params:
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        # Note that we don't do parameter check when testing mixed precision,
        # as FSDP will bring the full param back to fp32 but we did model.half()
        # for DDP so they wouldn't be equal. Further, DDP + model.half() would
        # run optimizer in reduced precision versus FSDP's full precision.
        if not mixed_precision:
            self.assertEqual(
                ref_full_params,
                shard_full_params,
                exact_device=True,
                msg="FullyShardedDataParallel didn't match PyTorch DDP",
            )

    def _get_wrapped_model(
        self, group, cuda_first=False, ignore_modules=False, config=None,
        **model_kwargs,
    ) -> FullyShardedDataParallel:
        """Return a ``TransformerWithSharedParams`` wrapped in FSDP.

        Args:
            group: process group for the model and the FSDP wrapper.
            cuda_first: move the model to CUDA before wrapping rather than
                after. No move happens at all when ``config`` enables
                parameter CPU offloading.
            ignore_modules: pass the model's ``get_ignored_modules()`` to
                FSDP as ``ignored_modules``; ``config`` must not already
                contain that key.
            config: keyword arguments for the FSDP constructor. The dict is
                copied, so the caller's object is never modified.
            **model_kwargs: forwarded to ``TransformerWithSharedParams``.
        """
        # Work on a copy so that adding "ignored_modules" below does not leak
        # into the caller's dict (which would trip the assert on reuse).
        config = {} if config is None else dict(config)
        move_to_cuda = not (
            "cpu_offload" in config and config["cpu_offload"].offload_params
        )
        transformer = TransformerWithSharedParams(group, **model_kwargs)
        if cuda_first and move_to_cuda:
            transformer = transformer.cuda()
        if ignore_modules:
            assert "ignored_modules" not in config, \
                "Do not pass in `ignored_modules` via `config`"
            config["ignored_modules"] = transformer.get_ignored_modules()
        model = FullyShardedDataParallel(transformer, group, **config)
        if not cuda_first and move_to_cuda:
            model = model.cuda()
        return model

    def _get_nonwrapped_model(
        self, group, **model_kwargs,
    ) -> torch.nn.Module:
        """Returns the non-wrapped model that is wrapped in
        :meth:`_get_wrapped_model`. The model used in these two methods should
        be kept in sync for tests that use both for parity comparisons."""
        return TransformerWithSharedParams(group, **model_kwargs).cuda()
|
|
|
|
|
|
class SkipModule(nn.Module):
    """Module holding a single bias-free ``Linear(10, 10)`` layer."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        out = self.lin(x)
        return out
|
|
|
|
|
|
class NestedLinear(nn.Module):
    """Bias-free ``Linear(10, 10)`` on CUDA, passed through ``wrap`` when
    ``fsdp_wrap`` is truthy."""

    def __init__(self, fsdp_wrap):
        super().__init__()
        linear = nn.Linear(10, 10, bias=False).cuda()
        self.nested_linear = wrap(linear) if fsdp_wrap else linear

    def forward(self, x):
        """Apply the (possibly wrapped) linear layer to ``x``."""
        out = self.nested_linear(x)
        return out
|
|
|
|
|
|
class SkipModel(nn.Module):
    """Three-stage model: a CUDA linear layer, a ``SkipModule`` on CUDA, and
    a ``NestedLinear`` passed through ``wrap``.

    ``double_nest`` is forwarded to ``NestedLinear`` as ``fsdp_wrap``.
    """

    def __init__(self, double_nest):
        super().__init__()
        self.linear = nn.Linear(10, 10, bias=False).cuda()
        self.linear_skip = SkipModule().cuda()
        nested = NestedLinear(fsdp_wrap=double_nest)
        self.nested_linear = wrap(nested)

    def forward(self, x):
        """Apply the three stages in sequence."""
        for stage in (self.linear, self.linear_skip, self.nested_linear):
            x = stage(x)
        return x
|
|
|
|
|
|
def _collect_total_grad_norm_fsdp(model, norm_type, rank):
    """Combine each rank's local gradient norm into a global norm.

    The local norm is raised to the ``norm_type`` power, all-reduced with
    SUM, and the root is taken again. For ``norm_type == inf`` the local
    maxima are all-reduced with MAX instead (exponent 1).

    Args:
        model: model whose gradients are measured.
        norm_type: norm order, or ``inf``.
        rank: passed as ``device=`` for the tensor that is all-reduced.
    """
    local_total = _collect_total_grad_norm_local(model, norm_type)
    if norm_type == inf:
        reduce_op, exponent = torch.distributed.ReduceOp.MAX, 1.0
    else:
        reduce_op, exponent = torch.distributed.ReduceOp.SUM, norm_type
    reduced = torch.tensor(local_total ** exponent, device=rank)
    dist.all_reduce(reduced, op=reduce_op)
    return reduced ** (1.0 / exponent)
|
|
|
|
|
|
def _collect_total_grad_norm_local(model, norm_type):
    """Return the total gradient norm over ``model``'s parameters.

    Every gradient is treated as a flat vector, whatever its shape, which is
    the quantity ``torch.nn.utils.clip_grad_norm_`` bounds. Parameters whose
    ``grad`` is ``None`` are skipped.

    Args:
        model: module whose parameter gradients are measured.
        norm_type: norm order ``p``, or ``inf`` for the max-abs norm.

    Returns:
        For ``inf``: the largest absolute gradient entry. Otherwise
        ``(sum_i ||g_i||_p ** p) ** (1 / p)``. If no parameter has a
        gradient, the float ``0.0`` is returned.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if norm_type == inf:
        if not grads:
            return 0.0
        return max(g.abs().max() for g in grads)

    total_norm = 0.0
    for g in grads:
        # vector_norm flattens the tensor; torch.linalg.norm with an explicit
        # order would compute a *matrix* norm for 2-D gradients instead.
        local_norm = torch.linalg.vector_norm(g, norm_type, dtype=torch.float32)
        total_norm += local_norm ** norm_type
    return total_norm ** (1.0 / norm_type)
|