[FSDP] CPU offload resubmit (#67249)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67249

Implements CPU offload for model parameters in FSDP.

- A CPU offload config class (`CPUOffload`) with a single `offload_params` attribute is added.
- If it is specified in the FSDP ctor, model parameters are moved back to CPU after sharding in `__init__`.
- In the forward pass, during lazy init, `p._local_shard` gets set to `p.data`, which is on CPU; we pin its memory here.
- In the forward pass, in `_rebuild_full_params`, we move `p.data` back to `self.compute_device` if necessary. Note that we don't use the device of `p._full_param_padded` because we don't always have this attr, but when we do it is always the same as `compute_device`.
- The same logic as above applies at the beginning of the backward pass.
- At the end of the forward and backward passes, `_use_param_local_shard` ensures the parameters are offloaded to CPU again by pointing `p.data` at `p._local_shard`, which always lives on CPU (a minimal sketch of this pointer dance follows the list).
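
To make the pointer dance above concrete, here is a minimal standalone sketch of the same pattern on a bare tensor; `compute_device`, the tensor size, and the variable names are illustrative stand-ins, not code from this PR (real FSDP does this on `p.data` inside the wrapper).

```python
import torch

compute_device = torch.device("cuda")  # stand-in for self.compute_device

# Lazy init: the sharded param lives on CPU; keep a pinned copy as the
# "local shard" so later host-to-device copies can be non-blocking.
p_data = torch.randn(1024)            # cf. p.data after CPU offload in __init__
local_shard = p_data.pin_memory()     # cf. p._local_shard

# _rebuild_full_params: move the shard to the compute device (then all-gather).
p_data = local_shard.to(compute_device, non_blocking=True)

# ... forward / backward compute runs with p_data on the GPU ...

# _use_param_local_shard: point back at the CPU shard, so the parameter is
# offloaded again at the end of forward/backward.
p_data = local_shard
assert p_data.device == torch.device("cpu")
```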

Regarding tests:
- We test 3 different types of init: 1) move the model to CUDA before wrapping with FSDP, 2) move the model to CUDA after wrapping with FSDP, 3) never move the model to CUDA.
- Case 1 is always supported. Case 2 is not supported with CPU offload and throws an error during the forward pass. Case 3 is only supported with CPU offload at the moment. (A schematic of the three cases follows this list.)
- Verifies all params are offloaded to CPU after init.
- Verifies all params are offloaded to CPU after forward and backward.
- Note that there is an issue with verifying exact parity when CPU offloading, but it appears to be related to transferring the model back and forth between CPU and CUDA. More details in https://github.com/pytorch/pytorch/pull/66961
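
For reference, a hedged sketch of the three init modes exercised above, mirroring the `FSDPInitMode` enum added in this diff; it assumes a NCCL process group has already been initialized and at least one GPU is visible, and the toy `nn.Linear` modules are placeholders.

```python
import torch.nn as nn
from torch.distributed._fsdp import FullyShardedDataParallel
from torch.distributed._fsdp.fully_sharded_data_parallel import CPUOffload

offload = CPUOffload(offload_params=True)

# Case 1 (CUDA_BEFORE): move the module to CUDA, then wrap. Always supported.
m1 = FullyShardedDataParallel(nn.Linear(8, 8).cuda(), cpu_offload=offload)

# Case 2 (CUDA_AFTER): wrap first, then call .cuda(). With offload_params=True
# this breaks the "shard stays on CPU" invariant, and the first forward pass
# raises an AssertionError: "Expected param to be on CPU".
m2 = FullyShardedDataParallel(nn.Linear(8, 8), cpu_offload=offload).cuda()

# Case 3 (CUDA_NEVER): never call .cuda(). Currently this only works with CPU
# offload, since _rebuild_full_params moves p.data to the compute device
# before the all-gather (NCCL only supports GPU tensors).
m3 = FullyShardedDataParallel(nn.Linear(8, 8), cpu_offload=offload)
```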
ghstack-source-id: 141851903

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D31911085

fbshipit-source-id: 3ddf73c070b55ce383e62251868d609004fc30e7
Rohan Varma 2021-11-02 23:25:54 -07:00 committed by Facebook GitHub Bot
parent 06d1be2447
commit 7f3326a6d2
3 changed files with 363 additions and 73 deletions


@@ -12,6 +12,7 @@ from torch.testing._internal.common_distributed import (
)
from torch.testing._internal.common_fsdp import (
DummyDDP,
FSDPInitMode,
FSDPTest,
MixtureOfExperts,
NestedWrappedModule,
@@ -25,6 +26,8 @@ from torch.testing._internal.common_utils import (
run_tests,
)
from torch.distributed._fsdp.fully_sharded_data_parallel import CPUOffload
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
@@ -44,59 +47,146 @@ class TestParityWithDDP(FSDPTest):
PyTorch DDP vs. FullyShardedDataParallel.
"""
@skip_if_lt_x_gpu(2)
def test_nested_wrapped_model(self):
self._test_identical_outputs(NestedWrappedModule)
def _get_init_modes_for_test(self, cpu_offload):
modes = [
FSDPInitMode.CUDA_AFTER,
FSDPInitMode.CUDA_BEFORE
]
# Note that FSDPInitMode.CUDA_NEVER currently works only with CPU
# offload, as we explicitly bring the param back to the CUDA device. In
# general it will not work, since we try to all_gather p.data which is
# on CPU but NCCL only supports GPU tensors.
if cpu_offload.offload_params:
modes.append(FSDPInitMode.CUDA_NEVER)
return modes
@skip_if_lt_x_gpu(2)
def test_nested_all_wrapped_model(self):
model_fn = functools.partial(NestedWrappedModule, wrap_everything=True)
self._test_identical_outputs(model_fn)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_nested_wrapped_model(self, cpu_offload):
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
self._test_identical_outputs(
NestedWrappedModule,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
)
@skip_if_lt_x_gpu(2)
def test_transformer_parameterized(self):
self._test_identical_outputs(TransformerWithSharedParams)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_nested_all_wrapped_model(self, cpu_offload):
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
model_fn = functools.partial(NestedWrappedModule, wrap_everything=True)
self._test_identical_outputs(
model_fn,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
)
@skip_if_lt_x_gpu(2)
def test_delayed_optim_step(self):
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_transformer_parameterized(self, cpu_offload):
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
self._test_identical_outputs(
TransformerWithSharedParams,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
)
@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_delayed_optim_step(self, cpu_offload):
# We use a model with a long CUDA delay right before the optimizer step.
# This tests our streams logic, and that we don't start the allgather
# until after the optimization step completes.
model_fn = functools.partial(
NestedWrappedModuleWithDelay, delay_after_loss_ms=250
)
self._test_identical_outputs(model_fn)
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
model_fn = functools.partial(
NestedWrappedModuleWithDelay, delay_after_loss_ms=250
)
self._test_identical_outputs(
model_fn,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
)
@skip_if_lt_x_gpu(2)
def test_delayed_reduce_scatter(self):
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_delayed_reduce_scatter(self, cpu_offload):
# We insert a delay in the torch.distributed._reduce_scatter_base op, so that
# the post_backward_stream takes much longer than the backward pass.
# This tests that we properly block at the end of the backward pass for
# the reductions to finish.
model_fn = functools.partial(
NestedWrappedModuleWithDelay, delay_before_reduction_ms=250
)
self._test_identical_outputs(model_fn)
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
model_fn = functools.partial(
NestedWrappedModuleWithDelay, delay_before_reduction_ms=250
)
self._test_identical_outputs(
model_fn,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
)
def _dummy_ddp_fn(self, model):
return DummyDDP(model)
@skip_if_lt_x_gpu(2)
def test_mixture_of_experts(self):
self._test_identical_outputs(
MixtureOfExperts,
# MixtureOfExperts implements custom reduce logic, so the reference
# behavior should use that logic instead of PyTorch DDP.
ref_ddp_fn=self._dummy_ddp_fn,
)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_mixture_of_experts(self, cpu_offload):
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
self._test_identical_outputs(
MixtureOfExperts,
# MixtureOfExperts implements custom reduce logic, so the reference
# behavior should use that logic instead of PyTorch DDP.
ref_ddp_fn=self._dummy_ddp_fn,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload,
)
@skip_if_lt_x_gpu(2)
def test_mixture_of_experts_with_delay_before_free(self):
model_fn = functools.partial(MixtureOfExperts, delay_before_free_ms=250)
self._test_identical_outputs(
model_fn,
ref_ddp_fn=self._dummy_ddp_fn,
)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_mixture_of_experts_with_delay_before_free(self, cpu_offload):
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
model_fn = functools.partial(MixtureOfExperts, delay_before_free_ms=250)
self._test_identical_outputs(
model_fn,
ref_ddp_fn=self._dummy_ddp_fn,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload,
)
class TestParamInit(FSDPTest):
@@ -194,6 +284,7 @@ class TestNoGrad(FSDPTest):
instantiate_parametrized_tests(TestHooks)
instantiate_parametrized_tests(TestParityWithDDP)
if __name__ == "__main__":
run_tests()


@@ -1,5 +1,6 @@
import functools
import traceback
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
TYPE_CHECKING,
@@ -30,6 +31,14 @@ if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
@dataclass
class CPUOffload:
offload_params: bool = False
# TODO: state dict offloading, activation offloading
# https://github.com/pytorch/pytorch/issues/67224
class TrainingState_(Enum):
"""
Simple enum to indicate what state FSDP is in. Used for asserting
@@ -78,12 +87,20 @@ class FullyShardedDataParallel(nn.Module):
module to be wrapped with FSDP.
process_group (Optional[ProcessGroup]):
process group for sharding
cpu_offload (Optional[CPUOffload]):
CPU offloading config. Currently, only parameter and gradient CPU
offload are supported. It can be enabled by passing in
cpu_offload=CPUOffload(offload_params=True). Note that this
currently implicitly enables gradient offloading to CPU so that
params and grads are on the same device, as required by the
optimizer. This API is subject to change.
"""
def __init__(
self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
cpu_offload: Optional[CPUOffload] = None
):
torch._C._log_api_usage_once("torch.distributed.fsdp")
super().__init__()
@@ -107,6 +124,7 @@ class FullyShardedDataParallel(nn.Module):
)
self.numel_padded_per_param: List[int] = []
self.cpu_offload = cpu_offload or CPUOffload()
# Only handle params which are not already sharded. This enables
# sharding individual layers of a Module, with an outer wrapper to
@@ -140,6 +158,11 @@ class FullyShardedDataParallel(nn.Module):
# Flag to guard against preparing gradients multiple times per backward pass.
self._pre_backward_hook_has_run = False
# If specified, offload parameter shard to CPU.
if self.cpu_offload.offload_params:
for p in self.params:
self._offload_to_cpu(p)
@property
def module(self) -> FlattenParamsWrapper:
"""make model.module accessible, just like DDP."""
@@ -154,6 +177,17 @@ class FullyShardedDataParallel(nn.Module):
factor *= 2
return float(factor)
def _offload_to_cpu(self, p):
"""
Offloads parameter to CPU from self.compute_device. If the parameter is
already on CPU then this is a noop.
"""
cpu_device = torch.device("cpu")
if p.device == cpu_device:
return
with torch.no_grad():
p.data = p.to(cpu_device)
def _cast_buffers(
self, device: Optional[torch.device] = None, memo: Optional[Set] = None
) -> None:
@@ -309,7 +343,9 @@ class FullyShardedDataParallel(nn.Module):
# Don't free the full params for the outer-most (root) instance,
# In most cases, root instance contains params in the last layers
# or has no params. In these cases, those params will be needed
# immediately after for the backward pass.
# immediately after for the backward pass. Note that this only
# applies currently when freeing parameters at end of layer's
# forward pass.
self.reshard_after_forward = False
# Due to the use of streams, we need to make sure the previous
@@ -343,10 +379,33 @@ class FullyShardedDataParallel(nn.Module):
p, "_orig_size"
), "Parameters should have been sharded during construction."
if hasattr(p, "_local_shard"):
# If CPU offloading, p._local_shard should have been placed on CPU
# during its first lazy construction.
if self.cpu_offload.offload_params:
assert (
p._local_shard.device == torch.device("cpu") # type: ignore[attr-defined]
), "Expected p._local_shard to be on CPU." # type: ignore[attr-defined]
return
# A single shard of the parameters.
# A single shard of the parameters. Also ensures p._local_shard ends up
# on CPU if we are CPU offloading, since p.data is on CPU during
# init.
if self.cpu_offload.offload_params:
assert p.device == torch.device("cpu"), "Expected param to be on CPU when cpu_offloading is enabled."
p._local_shard = p.data # type: ignore[attr-defined]
# If CPU offloading, pin the memory to enable faster CPU -> GPU device
# transfer.
if self.cpu_offload.offload_params:
assert p._local_shard.device == torch.device("cpu") # type: ignore[attr-defined]
p._local_shard = p._local_shard.pin_memory()  # type: ignore[attr-defined]
# When offloading parameters, also move the grad shard to CPU during
# backward pass. In this case, it's important to pre-allocate the
# CPU grad shard in pinned memory so that we can do a non-blocking
# transfer.
p._cpu_grad = torch.zeros_like( # type: ignore[attr-defined]
p,
device=torch.device("cpu")
).pin_memory()
# We also maintain a full-sized parameter of type self.compute_dtype.
# We resize the storage to size 0 at init (here) and only materialize
@@ -422,7 +481,8 @@ class FullyShardedDataParallel(nn.Module):
# Start of a forward pass.
self.training_state = TrainingState_.FORWARD
# All-gather full parameters.
# All-gather full parameters, moving them to compute_device if
# necessary.
self._rebuild_full_params()
# Register backward hooks to reshard params and reduce-scatter grads.
@@ -480,6 +540,8 @@ class FullyShardedDataParallel(nn.Module):
if self._is_root:
self._queue_wait_for_post_backward()
# All-gather full parameters, moving them to compute device if
# necessary.
self._rebuild_full_params()
self._pre_backward_hook_has_run = True
@@ -617,6 +679,23 @@ class FullyShardedDataParallel(nn.Module):
), "Currently the only way for _is_sharded to be False is \
world_size == 1"
# Regardless of whether the param is sharded, offload the grad to CPU
# if we are offloading params. This is so that param and grad reside on
# the same device, which is needed for the optimizer step.
if self.cpu_offload.offload_params:
# We specify non_blocking=True
# and ensure the appropriate synchronization is done by waiting on
# the relevant streams in _wait_for_post_backward.
param._cpu_grad.copy_( # type: ignore[attr-defined]
param.grad.detach(), non_blocking=True
)
# Don't let this memory get reused until after the transfer.
param.grad.data.record_stream(torch.cuda.current_stream())
# Point param.grad.data to CPU grad to offload it. Note that
# the transfer is async so it is not necessarily done until we
# explicitly synchronize in backward.
param.grad.data = param._cpu_grad # type: ignore[attr-defined]
# After _post_backward_hook returns, orig_grad_data will eventually
# go out of scope, at which point it could otherwise be freed for
# further reuse by the main stream while the div/reduce_scatter/copy
@@ -651,6 +730,13 @@ class FullyShardedDataParallel(nn.Module):
self._assert_state(TrainingState_.BACKWARD_PRE)
torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
if self.cpu_offload.offload_params:
# We need to wait for the non-blocking GPU ->
# CPU grad transfers to finish. TODO: investigate why this is needed
# and whether we can remove it, given that the transfer was issued on
# the post_backward stream and synchronized above.
torch.cuda.current_stream().synchronize()
# A backward pass is done, clean up below.
def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
@@ -712,6 +798,13 @@ class FullyShardedDataParallel(nn.Module):
with torch.cuda.stream(self._streams["all_gather"]):
for p in self.params:
if self.cpu_offload.offload_params:
# Move params to GPU if needed. Note that we don't use
# p._full_param_padded.device here because the attr is
# not always set, e.g. when world_size=1 and
# p._is_sharded = False. However, when it is set, the
# device is always self.compute_device.
p.data = p.data.to(self.compute_device, non_blocking=True)
# e.g., when world_size == 1
if not p._is_sharded: # type: ignore[attr-defined]
continue
@@ -724,7 +817,6 @@ class FullyShardedDataParallel(nn.Module):
continue
else:
# If full param has not been rebuilt or has been freed, call all gather
# Move params in CPU to CUDA for the all-gather.
p_data = p.data # type: ignore[attr-defined]
p_full_size = p._full_param_padded.size() # type: ignore[attr-defined]
assert (
@@ -759,7 +851,9 @@ class FullyShardedDataParallel(nn.Module):
@torch.no_grad()
def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
"""Free up storage for full parameters."""
"""
Free up storage for full parameters.
"""
if params is None:
params = self.params
current_stream = torch.cuda.current_stream()
@@ -780,10 +874,16 @@ class FullyShardedDataParallel(nn.Module):
@torch.no_grad()
def _use_param_local_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Use local shard for a list of params."""
"""Use local shard for a list of params. Also implicitly offloads
parameters back to CPU if we are CPU offloading."""
if params is None:
params = self.params
for p in params:
if self.cpu_offload.offload_params:
# Ensure p._local_shard resides on CPU if we are offloading params.
assert (
p._local_shard.device == torch.device("cpu") # type: ignore[attr-defined]
), "Expected p._local_shard to be on CPU" # type: ignore[attr-defined]
p.data = p._local_shard # type: ignore[attr-defined]
def _assert_state(self, state: Union[TrainingState_, List[TrainingState_]]) -> None:


@@ -1,5 +1,6 @@
# Owner(s): ["oncall: distributed"]
from contextlib import suppress
import sys
from unittest import mock
@@ -7,7 +8,9 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._fsdp import FullyShardedDataParallel
from torch.distributed._fsdp.fully_sharded_data_parallel import TrainingState_
from torch.distributed._fsdp.fully_sharded_data_parallel import (
TrainingState_, CPUOffload
)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
TEST_SKIPS,
@@ -17,7 +20,20 @@ from torch.testing._internal.common_utils import (
get_cycles_per_ms,
)
# get full params of a model recursively
from enum import Enum
class FSDPInitMode(Enum):
# Move model to CUDA before wrap
CUDA_BEFORE = 1
# Move model to CUDA after wrap
CUDA_AFTER = 2
# Don't move model to CUDA at all.
CUDA_NEVER = 3
# get full params of a model recursively. Note that if CPU offloading is
# enabled, this will also move the parameters back to GPU, due to the
# _rebuild_full_params call.
def get_full_params(model, recurse=True):
if recurse:
# get all params for any nested FSDP instances.
@@ -30,6 +46,8 @@ def get_full_params(model, recurse=True):
if model.module.flat_param is not None:
model.module._unflatten_params()
def _maybe_cuda(model, move_to_cuda):
return model.cuda() if move_to_cuda else model
class DummyProcessGroup:
def __init__(self, rank: int, size: int):
@@ -45,7 +63,8 @@ class DummyProcessGroup:
class TransformerWithSharedParams(nn.Module):
def __init__(
self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs
self, group, *args, d_vocab=23, d_model=16, add_bn=True,
fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs
):
super().__init__()
self.rank = group.rank()
@@ -54,6 +73,7 @@ class TransformerWithSharedParams(nn.Module):
assert (
d_vocab >= 12
), "dim of vocab should be larger than 12, as we use torch.arange(12) as input"
self.embed_tokens = nn.Embedding(d_vocab, d_model)
self.transformer = nn.Transformer(
d_model=d_model,
@@ -73,6 +93,8 @@ class TransformerWithSharedParams(nn.Module):
self.bs = 2
self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
self = _maybe_cuda(self, move_to_cuda)
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
@@ -99,36 +121,37 @@ class TransformerWithSharedParams(nn.Module):
class NestedWrappedModule(nn.Module):
def __init__(self, group, wrap_fsdp, wrap_everything=False):
def __init__(self, group, wrap_fsdp, *args, wrap_everything=False, fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
def _maybe_wrap(layer):
if wrap_fsdp:
return FullyShardedDataParallel(layer, group)
return FullyShardedDataParallel(layer, group, *args, **kwargs)
return layer
torch.manual_seed(0) # keep everything deterministic
if wrap_everything:
self.module = nn.Sequential(
_maybe_wrap(nn.Linear(8, 4)),
_maybe_wrap(nn.Linear(4, 16)),
_maybe_wrap(nn.Linear(16, 4)),
_maybe_wrap(nn.Linear(4, 8)),
_maybe_wrap(_maybe_cuda(nn.Linear(8, 4), move_to_cuda)),
_maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
_maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
_maybe_wrap(_maybe_cuda(nn.Linear(4, 8), move_to_cuda)),
)
else:
self.module = nn.Sequential(
nn.Linear(8, 4),
_maybe_cuda(nn.Linear(8, 4), move_to_cuda),
_maybe_wrap(
nn.Sequential(
_maybe_wrap(nn.Linear(4, 16)),
nn.Linear(16, 16),
)
_maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
_maybe_cuda(nn.Linear(16, 16), move_to_cuda),
),
),
_maybe_wrap(nn.Linear(16, 4)),
nn.Linear(4, 8),
_maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
_maybe_cuda(nn.Linear(4, 8), move_to_cuda),
)
def get_input(self, device):
@@ -182,8 +205,23 @@ class ModuleWithDelay(nn.Module):
class NestedWrappedModuleWithDelay(ModuleWithDelay):
def __init__(self, group, wrap_fsdp, **kwargs):
super().__init__(NestedWrappedModule(group, wrap_fsdp), **kwargs)
def __init__(
self,
group,
wrap_fsdp,
fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
cpu_offload=None,
**kwargs
):
super().__init__(
NestedWrappedModule(
group,
wrap_fsdp,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
),
**kwargs
)
class DummyDDP(nn.Module):
@@ -196,18 +234,18 @@ class DummyDDP(nn.Module):
class MixtureOfExperts(NestedWrappedModule):
def __init__(self, group, wrap_fsdp, delay_before_free_ms=0):
def __init__(self, group, wrap_fsdp, *args, delay_before_free_ms=0, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, **kwargs):
super().__init__(group, wrap_fsdp)
self.group = group
self.delay_before_free_ms = delay_before_free_ms
self.wrap_fsdp = wrap_fsdp
self.move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
d_expert = 23
d_shared = 12
d_input = 8
expert = nn.Linear(d_expert, d_shared)
expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)
self.num_expert_params = sum([p.numel() for p in expert.parameters()])
for p in expert.parameters():
@@ -216,19 +254,22 @@ class MixtureOfExperts(NestedWrappedModule):
# everything else is shared
torch.manual_seed(0)
shared = nn.Linear(d_shared, d_expert)
shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)
if wrap_fsdp:
# we create a process group of size 1 for the expert params
expert_group = torch.distributed.new_group(
[group.rank()]
) # world size 1 means no shard
expert = FullyShardedDataParallel(expert, expert_group) # type: ignore[assignment]
expert = FullyShardedDataParallel(expert, expert_group, **kwargs) # type: ignore[assignment]
shared = FullyShardedDataParallel(shared, group) # type: ignore[assignment]
shared = FullyShardedDataParallel(shared, group, **kwargs) # type: ignore[assignment]
self.module = nn.Sequential(
nn.Linear(d_input, d_shared), shared, expert, nn.Linear(d_shared, d_input)
_maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
shared,
expert,
_maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda)
)
def forward(self, x):
@@ -318,7 +359,9 @@ class FSDPTest(MultiProcessTestCase):
dist.destroy_process_group()
sys.exit(0)
def _train_for_several_steps(self, model, num_steps, autocast, lr=0.01):
def _train_for_several_steps(self, model, num_steps, autocast, lr=0.01, fsdp_cpu_offload=None):
cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
model_device = next(model.parameters()).device
# use SGD with momentum instead of Adam, since Adam is scale invariant
# and this makes it bad for tests
@@ -329,19 +372,40 @@ class FSDPTest(MultiProcessTestCase):
# Inputs are always on CUDA, regardless of CPU offloading or model.device
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
# Post-forward, if CPU offloading, model params should be on CPU.
if cpu_offload_params and isinstance(model, FullyShardedDataParallel):
for p in model.parameters():
# Params should always be on CPU, even if
# p._is_sharded=False
self.assertEqual(p.device, torch.device("cpu"))
loss = model.module.get_loss(input, output).to(model_device)
assert (
loss.dtype == torch.float32
), "loss data type should be float32, as the original \
parameter data type is float32."
model.module.run_backward(loss)
# Post-backward, if CPU offloading, model params should be on CPU.
if cpu_offload_params and isinstance(model, FullyShardedDataParallel):
for p in model.parameters():
# Params should always be on CPU, even if
# p._is_sharded=False
self.assertEqual(p.device, torch.device("cpu"))
optim.step()
if isinstance(model, FullyShardedDataParallel):
model._assert_state(TrainingState_.IDLE)
return loss.detach()
def _test_identical_outputs(
self, model_init_fn, ref_ddp_fn=None, num_steps=2, use_cuda=True, lr=0.01
self,
model_init_fn,
*args,
ref_ddp_fn=None,
num_steps=2,
fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
lr=0.01,
cpu_offload=CPUOffload(),
**kwargs
):
group = dist.distributed_c10d._get_default_group()
rank = group.rank()
@@ -354,25 +418,60 @@ class FSDPTest(MultiProcessTestCase):
else:
model = ref_ddp_fn(model)
ref_loss = self._train_for_several_steps(
model, num_steps, autocast=False, lr=lr
model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
)
ref_full_params = list(model.parameters())
# Confirm we get the same behavior using FullyShardedDataParallel.
model = model_init_fn(group=group, wrap_fsdp=True)
model = FullyShardedDataParallel(model)
if use_cuda:
try:
model = model_init_fn(group=group, wrap_fsdp=True, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload)
except Exception as e:
raise ValueError(f"model_init_fn {model_init_fn} got error {str(e)}")
cpu_offload = cpu_offload or CPUOffload() # disabled if not specified.
model = FullyShardedDataParallel(model, cpu_offload=cpu_offload)
# Call model.cuda() after init FSDP if specified.
if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device(
"cpu"
), "module parameters should be placed on cpu if use_cuda is False."
shard_loss = self._train_for_several_steps(
model, num_steps, autocast=False, lr=lr
# Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we
# expect the FSDP code to raise the error that we check for below, in the
# case of offloaded params.
if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
for p in model.parameters():
# Should be on CPU regardless of if param is sharded.
self.assertEqual(p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}")
only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
ctx = (
self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
if only_check_err else suppress()
)
with ctx:
shard_loss = self._train_for_several_steps(
model, num_steps, autocast=False, lr=lr,
fsdp_cpu_offload=cpu_offload,
)
# We only check for errors in the case we have the following setup:
# model = FSDP(model, cpu_offload=True)
# model = model.cuda()
# so skip the rest of this logic.
if only_check_err:
return
# If CPU offload, next call will change model params to GPU. Sanity
# check that params are on CPU before.
if cpu_offload.offload_params:
device_set = {p.device for p in model.parameters()}
self.assertEqual(
{torch.device("cpu")},
device_set,
f"Got device set {device_set}"
)
get_full_params(model)
shard_full_params = list(model.parameters())
if cpu_offload.offload_params:
shard_loss = shard_loss.cuda()
torch.testing.assert_allclose(ref_loss, shard_loss)
self.assertEqual(
ref_full_params,