diff --git a/test/distributed/fsdp/test_fsdp_core.py b/test/distributed/fsdp/test_fsdp_core.py index 468234ddf22..52b20334ae4 100644 --- a/test/distributed/fsdp/test_fsdp_core.py +++ b/test/distributed/fsdp/test_fsdp_core.py @@ -12,6 +12,7 @@ from torch.testing._internal.common_distributed import ( ) from torch.testing._internal.common_fsdp import ( DummyDDP, + FSDPInitMode, FSDPTest, MixtureOfExperts, NestedWrappedModule, @@ -25,6 +26,8 @@ from torch.testing._internal.common_utils import ( run_tests, ) +from torch.distributed._fsdp.fully_sharded_data_parallel import CPUOffload + if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) @@ -44,59 +47,146 @@ class TestParityWithDDP(FSDPTest): PyTorch DDP vs. FullyShardedDataParallel. """ - @skip_if_lt_x_gpu(2) - def test_nested_wrapped_model(self): - self._test_identical_outputs(NestedWrappedModule) + def _get_init_modes_for_test(self, cpu_offload): + modes = [ + FSDPInitMode.CUDA_AFTER, + FSDPInitMode.CUDA_BEFORE + ] + # Note that FSDPInitMode.CUDA_NEVER works currently only with CPU + # offload as we explicitly bring the param back to CUDA device. In + # general, it will not work since we try to all_gather p.data which is + # on CPU but NCCL only supports GPU. + if cpu_offload.offload_params: + modes.append(FSDPInitMode.CUDA_NEVER) + + return modes @skip_if_lt_x_gpu(2) - def test_nested_all_wrapped_model(self): - model_fn = functools.partial(NestedWrappedModule, wrap_everything=True) - self._test_identical_outputs(model_fn) + @parametrize( + "cpu_offload", + [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] + ) + def test_nested_wrapped_model(self, cpu_offload): + init_modes = self._get_init_modes_for_test(cpu_offload) + for fsdp_init_mode in init_modes: + with self.subTest(fsdp_init_mode=fsdp_init_mode): + self._test_identical_outputs( + NestedWrappedModule, + fsdp_init_mode=fsdp_init_mode, + cpu_offload=cpu_offload + ) @skip_if_lt_x_gpu(2) - def test_transformer_parameterized(self): - self._test_identical_outputs(TransformerWithSharedParams) + @parametrize( + "cpu_offload", + [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] + ) + def test_nested_all_wrapped_model(self, cpu_offload): + init_modes = self._get_init_modes_for_test(cpu_offload) + for fsdp_init_mode in init_modes: + with self.subTest(fsdp_init_mode=fsdp_init_mode): + model_fn = functools.partial(NestedWrappedModule, wrap_everything=True) + self._test_identical_outputs( + model_fn, + fsdp_init_mode=fsdp_init_mode, + cpu_offload=cpu_offload + ) @skip_if_lt_x_gpu(2) - def test_delayed_optim_step(self): + @parametrize( + "cpu_offload", + [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] + ) + def test_transformer_parameterized(self, cpu_offload): + init_modes = self._get_init_modes_for_test(cpu_offload) + for fsdp_init_mode in init_modes: + with self.subTest(fsdp_init_mode=fsdp_init_mode): + self._test_identical_outputs( + TransformerWithSharedParams, + fsdp_init_mode=fsdp_init_mode, + cpu_offload=cpu_offload + ) + + @skip_if_lt_x_gpu(2) + @parametrize( + "cpu_offload", + [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] + ) + def test_delayed_optim_step(self, cpu_offload): # We use a model with a long CUDA delay right before the optimizer step. # This tests our streams logic, and that we don't start the allgather # until after the optimization step completes. 
- model_fn = functools.partial( - NestedWrappedModuleWithDelay, delay_after_loss_ms=250 - ) - self._test_identical_outputs(model_fn) + init_modes = self._get_init_modes_for_test(cpu_offload) + for fsdp_init_mode in init_modes: + with self.subTest(fsdp_init_mode=fsdp_init_mode): + model_fn = functools.partial( + NestedWrappedModuleWithDelay, delay_after_loss_ms=250 + ) + self._test_identical_outputs( + model_fn, + fsdp_init_mode=fsdp_init_mode, + cpu_offload=cpu_offload + ) @skip_if_lt_x_gpu(2) - def test_delayed_reduce_scatter(self): + @parametrize( + "cpu_offload", + [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] + ) + def test_delayed_reduce_scatter(self, cpu_offload): # We insert a delay in the torch.distributed._reduce_scatter_base op, so that # the post_backward_stream takes much longer than the backward pass. # This tests that we properly block at the end of the backward pass for # the reductions to finish. - model_fn = functools.partial( - NestedWrappedModuleWithDelay, delay_before_reduction_ms=250 - ) - self._test_identical_outputs(model_fn) + init_modes = self._get_init_modes_for_test(cpu_offload) + for fsdp_init_mode in init_modes: + with self.subTest(fsdp_init_mode=fsdp_init_mode): + model_fn = functools.partial( + NestedWrappedModuleWithDelay, delay_before_reduction_ms=250 + ) + self._test_identical_outputs( + model_fn, + fsdp_init_mode=fsdp_init_mode, + cpu_offload=cpu_offload + ) def _dummy_ddp_fn(self, model): return DummyDDP(model) @skip_if_lt_x_gpu(2) - def test_mixture_of_experts(self): - self._test_identical_outputs( - MixtureOfExperts, - # MixtureOfExperts implements custom reduce logic, so the reference - # behavior should use that logic instead of PyTorch DDP. - ref_ddp_fn=self._dummy_ddp_fn, - ) + @parametrize( + "cpu_offload", + [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] + ) + def test_mixture_of_experts(self, cpu_offload): + init_modes = self._get_init_modes_for_test(cpu_offload) + for fsdp_init_mode in init_modes: + with self.subTest(fsdp_init_mode=fsdp_init_mode): + self._test_identical_outputs( + MixtureOfExperts, + # MixtureOfExperts implements custom reduce logic, so the reference + # behavior should use that logic instead of PyTorch DDP. 
+ ref_ddp_fn=self._dummy_ddp_fn, + fsdp_init_mode=fsdp_init_mode, + cpu_offload=cpu_offload, + ) @skip_if_lt_x_gpu(2) - def test_mixture_of_experts_with_delay_before_free(self): - model_fn = functools.partial(MixtureOfExperts, delay_before_free_ms=250) - self._test_identical_outputs( - model_fn, - ref_ddp_fn=self._dummy_ddp_fn, - ) + @parametrize( + "cpu_offload", + [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] + ) + def test_mixture_of_experts_with_delay_before_free(self, cpu_offload): + init_modes = self._get_init_modes_for_test(cpu_offload) + for fsdp_init_mode in init_modes: + with self.subTest(fsdp_init_mode=fsdp_init_mode): + model_fn = functools.partial(MixtureOfExperts, delay_before_free_ms=250) + self._test_identical_outputs( + model_fn, + ref_ddp_fn=self._dummy_ddp_fn, + fsdp_init_mode=fsdp_init_mode, + cpu_offload=cpu_offload, + ) class TestParamInit(FSDPTest): @@ -194,6 +284,7 @@ class TestNoGrad(FSDPTest): instantiate_parametrized_tests(TestHooks) +instantiate_parametrized_tests(TestParityWithDDP) if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_fsdp/fully_sharded_data_parallel.py b/torch/distributed/_fsdp/fully_sharded_data_parallel.py index d37458e5abe..4cc8bd1e73d 100644 --- a/torch/distributed/_fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/_fsdp/fully_sharded_data_parallel.py @@ -1,5 +1,6 @@ import functools import traceback +from dataclasses import dataclass from enum import Enum, auto from typing import ( TYPE_CHECKING, @@ -30,6 +31,14 @@ if TYPE_CHECKING: from collections import OrderedDict # noqa: F401 + +@dataclass +class CPUOffload: + offload_params: bool = False + # TODO: state dict offloading, activation offloading + # https://github.com/pytorch/pytorch/issues/67224 + + class TrainingState_(Enum): """ Simple enum to indicate what state FSDP is in. Used for asserting @@ -78,12 +87,20 @@ class FullyShardedDataParallel(nn.Module): module to be wrapped with FSDP. process_group (Optional[ProcessGroup]): process group for sharding + cpu_offload (Optional [CPUOffload]): + CPU offloading config. Currently, only parameter and grad CPU + offload is supported. It can be enabled via passing in + cpu_offload=CPUOffload(offload_params=True). Note that this + currently implicitly enables gradient offloading to CPU in order for + params and grads to be on same device to work with optimizer. This + API is subject to change. """ def __init__( self, module: nn.Module, process_group: Optional[ProcessGroup] = None, + cpu_offload: Optional[CPUOffload] = None ): torch._C._log_api_usage_once("torch.distributed.fsdp") super().__init__() @@ -107,6 +124,7 @@ class FullyShardedDataParallel(nn.Module): ) self.numel_padded_per_param: List[int] = [] + self.cpu_offload = cpu_offload or CPUOffload() # Only handle params which are not already sharded. This enables # sharding individual layers of a Module, with an outer wrapper to @@ -140,6 +158,11 @@ class FullyShardedDataParallel(nn.Module): # Flag to guard against preparing gradients multiple times per backward pass. self._pre_backward_hook_has_run = False + # If specified, offload parameter shard to CPU. + if self.cpu_offload.offload_params: + for p in self.params: + self._offload_to_cpu(p) + @property def module(self) -> FlattenParamsWrapper: """make model.module accessible, just like DDP.""" @@ -154,6 +177,17 @@ class FullyShardedDataParallel(nn.Module): factor *= 2 return float(factor) + def _offload_to_cpu(self, p): + """ + Offloads parameter to CPU from self.compute_device. 
If the parameter is + already on CPU then this is a noop. + """ + cpu_device = torch.device("cpu") + if p.device == cpu_device: + return + with torch.no_grad(): + p.data = p.to(cpu_device) + def _cast_buffers( self, device: Optional[torch.device] = None, memo: Optional[Set] = None ) -> None: @@ -309,7 +343,9 @@ class FullyShardedDataParallel(nn.Module): # Don't free the full params for the outer-most (root) instance, # In most cases, root instance contains params in the last layers # or has no params. In these cases, those params will be needed - # immediately after for the backward pass. + # immediately after for the backward pass. Note that this only + # applies currently when freeing parameters at end of layer's + # forward pass. self.reshard_after_forward = False # Due to the use of streams, we need to make sure the previous @@ -343,10 +379,33 @@ class FullyShardedDataParallel(nn.Module): p, "_orig_size" ), "Parameters should have been sharded during construction." if hasattr(p, "_local_shard"): + # If CPU offloading, p._local_shard should have been placed on CPU + # during its first lazy construction. + if self.cpu_offload.offload_params: + assert ( + p._local_shard.device == torch.device("cpu") # type: ignore[attr-defined] + ), "Expected p._local_shard to be on CPU." # type: ignore[attr-defined] return - # A single shard of the parameters. + # A single shard of the parameters. Also makes p._local_shard to be on + # CPU if we are CPU offloading, since p.data would be on CPU during + # init. + if self.cpu_offload.offload_params: + assert p.device == torch.device("cpu"), "Expected param to be on CPU when cpu_offloading is enabled." p._local_shard = p.data # type: ignore[attr-defined] + # If CPU offloading, pin the memory to enable faster CPU -> GPU device + # transfer. + if self.cpu_offload.offload_params: + assert p._local_shard.device == torch.device("cpu") # type: ignore[attr-defined] + p._local_shard.pin_memory() # type: ignore[attr-defined] + # When offloading parameters, also move the grad shard to CPU during + # backward pass. In this case, it's important to pre-allocate the + # CPU grad shard in pinned memory so that we can do a non-blocking + # transfer. + p._cpu_grad = torch.zeros_like( # type: ignore[attr-defined] + p, + device=torch.device("cpu") + ).pin_memory() # We also maintain a full-sized parameter of type self.compute_dtype. # We resize the storage to size 0 at init (here) and only materialize @@ -422,7 +481,8 @@ class FullyShardedDataParallel(nn.Module): # Start of a forward pass. self.training_state = TrainingState_.FORWARD - # All-gather full parameters. + # All-gather full parameters, moving them to compute_device if + # necessary. self._rebuild_full_params() # Register backward hooks to reshard params and reduce-scatter grads. @@ -480,6 +540,8 @@ class FullyShardedDataParallel(nn.Module): if self._is_root: self._queue_wait_for_post_backward() + # All-gather full parameters, moving them to compute device if + # necessary. self._rebuild_full_params() self._pre_backward_hook_has_run = True @@ -617,6 +679,23 @@ class FullyShardedDataParallel(nn.Module): ), "Currently the only way for _is_sharded to be False is \ world_size == 1" + # Regardless of sharding or not, offload the grad to CPU if we are + # offloading params. This is so param and grad reside on same device + # which is needed for the optimizer step. 
+ if self.cpu_offload.offload_params: + # We specify non_blocking=True + # and ensure the appropriate synchronization is done by waiting + # streams in _wait_for_post_backward. + param._cpu_grad.copy_( # type: ignore[attr-defined] + param.grad.detach(), non_blocking=True + ) + # Don't let this memory get reused until after the transfer. + param.grad.data.record_stream(torch.cuda.current_stream()) + # Point param.grad.data to CPU grad to offload it. Note that + # the transfer is async so it is not necessarily done until we + # explicitly synchronize in backward. + param.grad.data = param._cpu_grad # type: ignore[attr-defined] + # After _post_backward_hook returns, orig_grad_data will eventually # go out of scope, at which point it could otherwise be freed for # further reuse by the main stream while the div/reduce_scatter/copy @@ -651,6 +730,13 @@ class FullyShardedDataParallel(nn.Module): self._assert_state(TrainingState_.BACKWARD_PRE) torch.cuda.current_stream().wait_stream(self._streams["post_backward"]) + if self.cpu_offload.offload_params: + # We need to wait for the non-blocking GPU -> + # CPU grad transfers to finish. TODO investigate why this is needed + # and if we can remove it, as we've done transfer on post_backward + # stream and synchronized it above. + torch.cuda.current_stream().synchronize() + # A backward pass is done, clean up below. def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None: @@ -712,6 +798,13 @@ class FullyShardedDataParallel(nn.Module): with torch.cuda.stream(self._streams["all_gather"]): for p in self.params: + if self.cpu_offload.offload_params: + # Move params to GPU if needed. Note that we don't use + # self._full_param_padded.device here because the attr is + # not set always, i.e. when world_size=1 and + # p._is_sharded = False. However when it is set, the + # device is always self.compute_device. + p.data = p.data.to(self.compute_device, non_blocking=True) # e.g., when world_size == 1 if not p._is_sharded: # type: ignore[attr-defined] continue @@ -724,7 +817,6 @@ class FullyShardedDataParallel(nn.Module): continue else: # If full param has not been rebuilt or has been freed, call all gather - # Move params in CPU to CUDA for the all-gather. p_data = p.data # type: ignore[attr-defined] p_full_size = p._full_param_padded.size() # type: ignore[attr-defined] assert ( @@ -759,7 +851,9 @@ class FullyShardedDataParallel(nn.Module): @torch.no_grad() def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: - """Free up storage for full parameters.""" + """ + Free up storage for full parameters. + """ if params is None: params = self.params current_stream = torch.cuda.current_stream() @@ -780,10 +874,16 @@ class FullyShardedDataParallel(nn.Module): @torch.no_grad() def _use_param_local_shard(self, params: Optional[List[Parameter]] = None) -> None: - """Use local shard for a list of params.""" + """Use local shard for a list of params. Also implicitly offloads + parameters back to CPU if we are CPU offloading.""" if params is None: params = self.params for p in params: + if self.cpu_offload.offload_params: + # Ensure local_shard resides in CPU if we are offloading params. 
+ assert ( + p._local_shard.device == torch.device("cpu") # type: ignore[attr-defined] + ), "Expected p._local_shard to be on CPU" # type: ignore[attr-defined] p.data = p._local_shard # type: ignore[attr-defined] def _assert_state(self, state: Union[TrainingState_, List[TrainingState_]]) -> None: diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index e76161535ba..019804b65d2 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +from contextlib import suppress import sys from unittest import mock @@ -7,7 +8,9 @@ import torch import torch.distributed as dist import torch.nn as nn from torch.distributed._fsdp import FullyShardedDataParallel -from torch.distributed._fsdp.fully_sharded_data_parallel import TrainingState_ +from torch.distributed._fsdp.fully_sharded_data_parallel import ( + TrainingState_, CPUOffload +) from torch.testing._internal.common_distributed import ( MultiProcessTestCase, TEST_SKIPS, @@ -17,7 +20,20 @@ from torch.testing._internal.common_utils import ( get_cycles_per_ms, ) -# get full params of a model recursively +from enum import Enum + + +class FSDPInitMode(Enum): + # Move model to CUDA before wrap + CUDA_BEFORE = 1 + # Move model to CUDA after wrap + CUDA_AFTER = 2 + # Don't move model to CUDA at all. + CUDA_NEVER = 3 + +# get full params of a model recursively. Note that if CPU offloading, it will +# also automatically move the parameters to GPU, due to _rebuild_full_params +# call. def get_full_params(model, recurse=True): if recurse: # get all params for any nested FSDP instances. @@ -30,6 +46,8 @@ def get_full_params(model, recurse=True): if model.module.flat_param is not None: model.module._unflatten_params() +def _maybe_cuda(model, move_to_cuda): + return model.cuda() if move_to_cuda else model class DummyProcessGroup: def __init__(self, rank: int, size: int): @@ -45,7 +63,8 @@ class DummyProcessGroup: class TransformerWithSharedParams(nn.Module): def __init__( - self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs + self, group, *args, d_vocab=23, d_model=16, add_bn=True, + fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs ): super().__init__() self.rank = group.rank() @@ -54,6 +73,7 @@ class TransformerWithSharedParams(nn.Module): assert ( d_vocab >= 12 ), "dim of vocab should be larger than 12, as we use torch.arange(12) as input" + self.embed_tokens = nn.Embedding(d_vocab, d_model) self.transformer = nn.Transformer( d_model=d_model, @@ -73,6 +93,8 @@ class TransformerWithSharedParams(nn.Module): self.bs = 2 self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity() + move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE + self = _maybe_cuda(self, move_to_cuda) def get_input(self, device): torch.manual_seed(1 + self.rank) # keep everything deterministic @@ -99,36 +121,37 @@ class TransformerWithSharedParams(nn.Module): class NestedWrappedModule(nn.Module): - def __init__(self, group, wrap_fsdp, wrap_everything=False): + def __init__(self, group, wrap_fsdp, *args, wrap_everything=False, fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs): super().__init__() self.rank = group.rank() self.world_size = group.size() + move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE def _maybe_wrap(layer): if wrap_fsdp: - return FullyShardedDataParallel(layer, group) + return FullyShardedDataParallel(layer, group, *args, **kwargs) return layer torch.manual_seed(0) # keep everything 
deterministic if wrap_everything: self.module = nn.Sequential( - _maybe_wrap(nn.Linear(8, 4)), - _maybe_wrap(nn.Linear(4, 16)), - _maybe_wrap(nn.Linear(16, 4)), - _maybe_wrap(nn.Linear(4, 8)), + _maybe_wrap(_maybe_cuda(nn.Linear(8, 4), move_to_cuda)), + _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)), + _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)), + _maybe_wrap(_maybe_cuda(nn.Linear(4, 8), move_to_cuda)), ) else: self.module = nn.Sequential( - nn.Linear(8, 4), + _maybe_cuda(nn.Linear(8, 4), move_to_cuda), _maybe_wrap( nn.Sequential( - _maybe_wrap(nn.Linear(4, 16)), - nn.Linear(16, 16), - ) + _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)), + _maybe_cuda(nn.Linear(16, 16), move_to_cuda), + ), ), - _maybe_wrap(nn.Linear(16, 4)), - nn.Linear(4, 8), + _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)), + _maybe_cuda(nn.Linear(4, 8), move_to_cuda), ) def get_input(self, device): @@ -182,8 +205,23 @@ class ModuleWithDelay(nn.Module): class NestedWrappedModuleWithDelay(ModuleWithDelay): - def __init__(self, group, wrap_fsdp, **kwargs): - super().__init__(NestedWrappedModule(group, wrap_fsdp), **kwargs) + def __init__( + self, + group, + wrap_fsdp, + fsdp_init_mode=FSDPInitMode.CUDA_AFTER, + cpu_offload=None, + **kwargs + ): + super().__init__( + NestedWrappedModule( + group, + wrap_fsdp, + fsdp_init_mode=fsdp_init_mode, + cpu_offload=cpu_offload + ), + **kwargs + ) class DummyDDP(nn.Module): @@ -196,18 +234,18 @@ class DummyDDP(nn.Module): class MixtureOfExperts(NestedWrappedModule): - def __init__(self, group, wrap_fsdp, delay_before_free_ms=0): + def __init__(self, group, wrap_fsdp, *args, delay_before_free_ms=0, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, **kwargs): super().__init__(group, wrap_fsdp) self.group = group self.delay_before_free_ms = delay_before_free_ms self.wrap_fsdp = wrap_fsdp - + self.move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE # "expert" params are different on each rank torch.manual_seed(42 + group.rank()) d_expert = 23 d_shared = 12 d_input = 8 - expert = nn.Linear(d_expert, d_shared) + expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda) self.num_expert_params = sum([p.numel() for p in expert.parameters()]) for p in expert.parameters(): @@ -216,19 +254,22 @@ class MixtureOfExperts(NestedWrappedModule): # everything else is shared torch.manual_seed(0) - shared = nn.Linear(d_shared, d_expert) + shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda) if wrap_fsdp: # we create a process group of size 1 for the expert params expert_group = torch.distributed.new_group( [group.rank()] ) # world size 1 means no shard - expert = FullyShardedDataParallel(expert, expert_group) # type: ignore[assignment] + expert = FullyShardedDataParallel(expert, expert_group, **kwargs) # type: ignore[assignment] - shared = FullyShardedDataParallel(shared, group) # type: ignore[assignment] + shared = FullyShardedDataParallel(shared, group, **kwargs) # type: ignore[assignment] self.module = nn.Sequential( - nn.Linear(d_input, d_shared), shared, expert, nn.Linear(d_shared, d_input) + _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda), + shared, + expert, + _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda) ) def forward(self, x): @@ -318,7 +359,9 @@ class FSDPTest(MultiProcessTestCase): dist.destroy_process_group() sys.exit(0) - def _train_for_several_steps(self, model, num_steps, autocast, lr=0.01): + def _train_for_several_steps(self, model, num_steps, autocast, lr=0.01, fsdp_cpu_offload=None): + 
cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params + model_device = next(model.parameters()).device # use SGD with momentum instead of Adam, since Adam is scale invariant # and this makes it bad for tests @@ -329,19 +372,40 @@ class FSDPTest(MultiProcessTestCase): # Inputs always cuda regardless of cpu offloading, or model.device input = model.module.get_input(torch.device("cuda")) output = model(*input) + # Post-forward, if CPU offloading model param should be on CPU. + if cpu_offload_params and isinstance(model, FullyShardedDataParallel): + for p in model.parameters(): + # Params should always be on CPU, even if + # p._is_sharded=False + self.assertEqual(p.device, torch.device("cpu")) + loss = model.module.get_loss(input, output).to(model_device) assert ( loss.dtype == torch.float32 ), "loss data type should be float32, as the original \ parameter data type is float32." model.module.run_backward(loss) + # Post-backward, if CPU offloading model params should be on CPU. + if cpu_offload_params and isinstance(model, FullyShardedDataParallel): + for p in model.parameters(): + # Params should always be on CPU, even if + # p._is_sharded=False + self.assertEqual(p.device, torch.device("cpu")) optim.step() if isinstance(model, FullyShardedDataParallel): model._assert_state(TrainingState_.IDLE) return loss.detach() def _test_identical_outputs( - self, model_init_fn, ref_ddp_fn=None, num_steps=2, use_cuda=True, lr=0.01 + self, + model_init_fn, + *args, + ref_ddp_fn=None, + num_steps=2, + fsdp_init_mode=FSDPInitMode.CUDA_AFTER, + lr=0.01, + cpu_offload=CPUOffload(), + **kwargs ): group = dist.distributed_c10d._get_default_group() rank = group.rank() @@ -354,25 +418,60 @@ class FSDPTest(MultiProcessTestCase): else: model = ref_ddp_fn(model) ref_loss = self._train_for_several_steps( - model, num_steps, autocast=False, lr=lr + model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload ) ref_full_params = list(model.parameters()) # Confirm we get the same behavior using FullyShardedDataParallel. - model = model_init_fn(group=group, wrap_fsdp=True) - model = FullyShardedDataParallel(model) - if use_cuda: + try: + model = model_init_fn(group=group, wrap_fsdp=True, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload) + except Exception as e: + raise ValueError(f"model_Init_fn {model_init_fn} got error {str(e)}") + + cpu_offload = cpu_offload or CPUOffload() # disabled if not specified. + model = FullyShardedDataParallel(model, cpu_offload=cpu_offload) + # Call model.cuda() after init FSDP if specified. + if fsdp_init_mode == FSDPInitMode.CUDA_AFTER: model = model.cuda() - else: - assert next(model.parameters()).device == torch.device( - "cpu" - ), "module parameters should be placed on cpu if use_cuda is False." - shard_loss = self._train_for_several_steps( - model, num_steps, autocast=False, lr=lr + + # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we + # expect FSDP code to raise error that we check below, in the case of + # offload params. + if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params: + for p in model.parameters(): + # Should be on CPU regardless of if param is sharded. 
+ self.assertEqual(p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}") + + only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params + ctx = ( + self.assertRaisesRegex(AssertionError, "Expected param to be on CPU") + if only_check_err else suppress() ) + with ctx: + shard_loss = self._train_for_several_steps( + model, num_steps, autocast=False, lr=lr, + fsdp_cpu_offload=cpu_offload, + ) + # We only check for errors in the case we have the following setup: + # model = FSDP(model, cpu_offload=True) + # model = model.cuda() + # so skip the rest of this logic. + if only_check_err: + return + # If CPU offload, next call will change model params to GPU. Sanity + # check that params are on CPU before. + if cpu_offload.offload_params: + device_set = {p.device for p in model.parameters()} + self.assertEqual( + {torch.device("cpu")}, + device_set, + f"Got device set {device_set}" + ) get_full_params(model) shard_full_params = list(model.parameters()) + if cpu_offload.offload_params: + shard_loss = shard_loss.cuda() torch.testing.assert_allclose(ref_loss, shard_loss) self.assertEqual( ref_full_params,