Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
[FSDP] CPU offload resubmit (#67249)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67249

Implements CPU offload for model parameters in FSDP.
- A CPUOffload config class with only an offload_params attribute is created.
- If it is specified in the FSDP constructor, model parameters are moved back to CPU after sharding in __init__.
- In the forward pass, during lazy init, p._local_shard gets set to p.data so it is on CPU. We pin_memory here.
- In the forward pass, in _rebuild_full_params, we move p.data back to self.compute_device if necessary. Note that we don't use the device of p._full_param_padded because we don't always have this attr, but when we do it is always the same as compute_device.
- The same logic applies at the beginning of the backward pass.
- At the end of forward and backward, `_use_param_local_shard` ensures the parameters are offloaded to CPU again by pointing p.data to p._local_shard, which is always on CPU.

Regarding tests:
- We test 3 different types of init: 1) CUDA the model before wrapping with FSDP, 2) CUDA the model after wrapping with FSDP, 3) never CUDA the model.
- Case 1 is always supported. Case 2 is not supported with CPU offload and throws an error during the forward pass. Case 3 is only supported with CPU offload at the moment.
- Verifies all params are offloaded to CPU after init.
- Verifies all params are offloaded to CPU after forward and backward.
- Note that there is an issue with verifying exact parity when CPU offloading, but it appears to be related to transferring the model back and forth between CPU and CUDA. More details in https://github.com/pytorch/pytorch/pull/66961

ghstack-source-id: 141851903

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D31911085

fbshipit-source-id: 3ddf73c070b55ce383e62251868d609004fc30e7
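For reference, a minimal single-process usage sketch of the new config, based on the description above. It assumes a CUDA build with NCCL, an explicitly initialized world-size-1 process group (the real tests spawn multiple ranks via MultiProcessTestCase), and the private torch.distributed._fsdp import path used by this commit; the main() wrapper and the tiny Linear module are illustrative only.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._fsdp.fully_sharded_data_parallel import CPUOffload

def main():
    # Single-rank NCCL group purely for illustration.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    # Case 1 from the test plan: move the module to CUDA before wrapping.
    module = nn.Linear(8, 4).cuda()
    # offload_params=True keeps each rank's parameter shard (and, implicitly,
    # its gradient shard) on CPU outside of forward/backward.
    model = FSDP(module, cpu_offload=CPUOffload(offload_params=True))

    out = model(torch.randn(2, 8, device="cuda"))
    out.sum().backward()
    # Per the summary, sharded parameters live on CPU again after fwd/bwd.
    assert all(p.device == torch.device("cpu") for p in model.parameters())
    dist.destroy_process_group()

if __name__ == "__main__":
    main()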
This commit is contained in:
parent 06d1be2447
commit 7f3326a6d2
@@ -12,6 +12,7 @@ from torch.testing._internal.common_distributed import (
)
from torch.testing._internal.common_fsdp import (
    DummyDDP,
    FSDPInitMode,
    FSDPTest,
    MixtureOfExperts,
    NestedWrappedModule,
@@ -25,6 +26,8 @@ from torch.testing._internal.common_utils import (
    run_tests,
)

from torch.distributed._fsdp.fully_sharded_data_parallel import CPUOffload


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
@@ -44,59 +47,146 @@ class TestParityWithDDP(FSDPTest):
    PyTorch DDP vs. FullyShardedDataParallel.
    """

    @skip_if_lt_x_gpu(2)
    def test_nested_wrapped_model(self):
        self._test_identical_outputs(NestedWrappedModule)
    def _get_init_modes_for_test(self, cpu_offload):
        modes = [
            FSDPInitMode.CUDA_AFTER,
            FSDPInitMode.CUDA_BEFORE
        ]
        # Note that FSDPInitMode.CUDA_NEVER works currently only with CPU
        # offload as we explicitly bring the param back to CUDA device. In
        # general, it will not work since we try to all_gather p.data which is
        # on CPU but NCCL only supports GPU.
        if cpu_offload.offload_params:
            modes.append(FSDPInitMode.CUDA_NEVER)

        return modes

    @skip_if_lt_x_gpu(2)
    def test_nested_all_wrapped_model(self):
        model_fn = functools.partial(NestedWrappedModule, wrap_everything=True)
        self._test_identical_outputs(model_fn)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
    )
    def test_nested_wrapped_model(self, cpu_offload):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(
                    NestedWrappedModule,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload
                )

    @skip_if_lt_x_gpu(2)
    def test_transformer_parameterized(self):
        self._test_identical_outputs(TransformerWithSharedParams)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
    )
    def test_nested_all_wrapped_model(self, cpu_offload):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(NestedWrappedModule, wrap_everything=True)
                self._test_identical_outputs(
                    model_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload
                )

    @skip_if_lt_x_gpu(2)
    def test_delayed_optim_step(self):
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
    )
    def test_transformer_parameterized(self, cpu_offload):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(
                    TransformerWithSharedParams,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
    )
    def test_delayed_optim_step(self, cpu_offload):
        # We use a model with a long CUDA delay right before the optimizer step.
        # This tests our streams logic, and that we don't start the allgather
        # until after the optimization step completes.
        model_fn = functools.partial(
            NestedWrappedModuleWithDelay, delay_after_loss_ms=250
        )
        self._test_identical_outputs(model_fn)
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(
                    NestedWrappedModuleWithDelay, delay_after_loss_ms=250
                )
                self._test_identical_outputs(
                    model_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload
                )

    @skip_if_lt_x_gpu(2)
    def test_delayed_reduce_scatter(self):
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
    )
    def test_delayed_reduce_scatter(self, cpu_offload):
        # We insert a delay in the torch.distributed._reduce_scatter_base op, so that
        # the post_backward_stream takes much longer than the backward pass.
        # This tests that we properly block at the end of the backward pass for
        # the reductions to finish.
        model_fn = functools.partial(
            NestedWrappedModuleWithDelay, delay_before_reduction_ms=250
        )
        self._test_identical_outputs(model_fn)
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(
                    NestedWrappedModuleWithDelay, delay_before_reduction_ms=250
                )
                self._test_identical_outputs(
                    model_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload
                )

    def _dummy_ddp_fn(self, model):
        return DummyDDP(model)

    @skip_if_lt_x_gpu(2)
    def test_mixture_of_experts(self):
        self._test_identical_outputs(
            MixtureOfExperts,
            # MixtureOfExperts implements custom reduce logic, so the reference
            # behavior should use that logic instead of PyTorch DDP.
            ref_ddp_fn=self._dummy_ddp_fn,
        )
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
    )
    def test_mixture_of_experts(self, cpu_offload):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(
                    MixtureOfExperts,
                    # MixtureOfExperts implements custom reduce logic, so the reference
                    # behavior should use that logic instead of PyTorch DDP.
                    ref_ddp_fn=self._dummy_ddp_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                )

    @skip_if_lt_x_gpu(2)
    def test_mixture_of_experts_with_delay_before_free(self):
        model_fn = functools.partial(MixtureOfExperts, delay_before_free_ms=250)
        self._test_identical_outputs(
            model_fn,
            ref_ddp_fn=self._dummy_ddp_fn,
        )
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
    )
    def test_mixture_of_experts_with_delay_before_free(self, cpu_offload):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(MixtureOfExperts, delay_before_free_ms=250)
                self._test_identical_outputs(
                    model_fn,
                    ref_ddp_fn=self._dummy_ddp_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                )


class TestParamInit(FSDPTest):
@@ -194,6 +284,7 @@ class TestNoGrad(FSDPTest):


instantiate_parametrized_tests(TestHooks)
instantiate_parametrized_tests(TestParityWithDDP)

if __name__ == "__main__":
    run_tests()
@@ -1,5 +1,6 @@
import functools
import traceback
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
    TYPE_CHECKING,
@@ -30,6 +31,14 @@ if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401


@dataclass
class CPUOffload:
    offload_params: bool = False
    # TODO: state dict offloading, activation offloading
    # https://github.com/pytorch/pytorch/issues/67224


class TrainingState_(Enum):
    """
    Simple enum to indicate what state FSDP is in. Used for asserting
@@ -78,12 +87,20 @@ class FullyShardedDataParallel(nn.Module):
            module to be wrapped with FSDP.
        process_group (Optional[ProcessGroup]):
            process group for sharding
        cpu_offload (Optional [CPUOffload]):
            CPU offloading config. Currently, only parameter and grad CPU
            offload is supported. It can be enabled via passing in
            cpu_offload=CPUOffload(offload_params=True). Note that this
            currently implicitly enables gradient offloading to CPU in order for
            params and grads to be on same device to work with optimizer. This
            API is subject to change.
    """

    def __init__(
        self,
        module: nn.Module,
        process_group: Optional[ProcessGroup] = None,
        cpu_offload: Optional[CPUOffload] = None
    ):
        torch._C._log_api_usage_once("torch.distributed.fsdp")
        super().__init__()
@@ -107,6 +124,7 @@ class FullyShardedDataParallel(nn.Module):
        )

        self.numel_padded_per_param: List[int] = []
        self.cpu_offload = cpu_offload or CPUOffload()

        # Only handle params which are not already sharded. This enables
        # sharding individual layers of a Module, with an outer wrapper to
@@ -140,6 +158,11 @@ class FullyShardedDataParallel(nn.Module):
        # Flag to guard against preparing gradients multiple times per backward pass.
        self._pre_backward_hook_has_run = False

        # If specified, offload parameter shard to CPU.
        if self.cpu_offload.offload_params:
            for p in self.params:
                self._offload_to_cpu(p)

    @property
    def module(self) -> FlattenParamsWrapper:
        """make model.module accessible, just like DDP."""
@@ -154,6 +177,17 @@ class FullyShardedDataParallel(nn.Module):
            factor *= 2
        return float(factor)

    def _offload_to_cpu(self, p):
        """
        Offloads parameter to CPU from self.compute_device. If the parameter is
        already on CPU then this is a noop.
        """
        cpu_device = torch.device("cpu")
        if p.device == cpu_device:
            return
        with torch.no_grad():
            p.data = p.to(cpu_device)

    def _cast_buffers(
        self, device: Optional[torch.device] = None, memo: Optional[Set] = None
    ) -> None:
@@ -309,7 +343,9 @@ class FullyShardedDataParallel(nn.Module):
            # Don't free the full params for the outer-most (root) instance,
            # In most cases, root instance contains params in the last layers
            # or has no params. In these cases, those params will be needed
            # immediately after for the backward pass.
            # immediately after for the backward pass. Note that this only
            # applies currently when freeing parameters at end of layer's
            # forward pass.
            self.reshard_after_forward = False

            # Due to the use of streams, we need to make sure the previous
@@ -343,10 +379,33 @@ class FullyShardedDataParallel(nn.Module):
            p, "_orig_size"
        ), "Parameters should have been sharded during construction."
        if hasattr(p, "_local_shard"):
            # If CPU offloading, p._local_shard should have been placed on CPU
            # during its first lazy construction.
            if self.cpu_offload.offload_params:
                assert (
                    p._local_shard.device == torch.device("cpu")  # type: ignore[attr-defined]
                ), "Expected p._local_shard to be on CPU."  # type: ignore[attr-defined]
            return

        # A single shard of the parameters.
        # A single shard of the parameters. Also makes p._local_shard to be on
        # CPU if we are CPU offloading, since p.data would be on CPU during
        # init.
        if self.cpu_offload.offload_params:
            assert p.device == torch.device("cpu"), "Expected param to be on CPU when cpu_offloading is enabled."
        p._local_shard = p.data  # type: ignore[attr-defined]
        # If CPU offloading, pin the memory to enable faster CPU -> GPU device
        # transfer.
        if self.cpu_offload.offload_params:
            assert p._local_shard.device == torch.device("cpu")  # type: ignore[attr-defined]
            p._local_shard.pin_memory()  # type: ignore[attr-defined]
            # When offloading parameters, also move the grad shard to CPU during
            # backward pass. In this case, it's important to pre-allocate the
            # CPU grad shard in pinned memory so that we can do a non-blocking
            # transfer.
            p._cpu_grad = torch.zeros_like(  # type: ignore[attr-defined]
                p,
                device=torch.device("cpu")
            ).pin_memory()

        # We also maintain a full-sized parameter of type self.compute_dtype.
        # We resize the storage to size 0 at init (here) and only materialize
@@ -422,7 +481,8 @@ class FullyShardedDataParallel(nn.Module):
        # Start of a forward pass.
        self.training_state = TrainingState_.FORWARD

        # All-gather full parameters.
        # All-gather full parameters, moving them to compute_device if
        # necessary.
        self._rebuild_full_params()

        # Register backward hooks to reshard params and reduce-scatter grads.
@@ -480,6 +540,8 @@ class FullyShardedDataParallel(nn.Module):
            if self._is_root:
                self._queue_wait_for_post_backward()

            # All-gather full parameters, moving them to compute device if
            # necessary.
            self._rebuild_full_params()

            self._pre_backward_hook_has_run = True
@@ -617,6 +679,23 @@ class FullyShardedDataParallel(nn.Module):
            ), "Currently the only way for _is_sharded to be False is \
                world_size == 1"

            # Regardless of sharding or not, offload the grad to CPU if we are
            # offloading params. This is so param and grad reside on same device
            # which is needed for the optimizer step.
            if self.cpu_offload.offload_params:
                # We specify non_blocking=True
                # and ensure the appropriate synchronization is done by waiting
                # streams in _wait_for_post_backward.
                param._cpu_grad.copy_(  # type: ignore[attr-defined]
                    param.grad.detach(), non_blocking=True
                )
                # Don't let this memory get reused until after the transfer.
                param.grad.data.record_stream(torch.cuda.current_stream())
                # Point param.grad.data to CPU grad to offload it. Note that
                # the transfer is async so it is not necessarily done until we
                # explicitly synchronize in backward.
                param.grad.data = param._cpu_grad  # type: ignore[attr-defined]

            # After _post_backward_hook returns, orig_grad_data will eventually
            # go out of scope, at which point it could otherwise be freed for
            # further reuse by the main stream while the div/reduce_scatter/copy
@@ -651,6 +730,13 @@ class FullyShardedDataParallel(nn.Module):
            self._assert_state(TrainingState_.BACKWARD_PRE)

        torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
        if self.cpu_offload.offload_params:
            # We need to wait for the non-blocking GPU ->
            # CPU grad transfers to finish. TODO investigate why this is needed
            # and if we can remove it, as we've done transfer on post_backward
            # stream and synchronized it above.
            torch.cuda.current_stream().synchronize()

        # A backward pass is done, clean up below.

        def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
@@ -712,6 +798,13 @@ class FullyShardedDataParallel(nn.Module):

        with torch.cuda.stream(self._streams["all_gather"]):
            for p in self.params:
                if self.cpu_offload.offload_params:
                    # Move params to GPU if needed. Note that we don't use
                    # self._full_param_padded.device here because the attr is
                    # not set always, i.e. when world_size=1 and
                    # p._is_sharded = False. However when it is set, the
                    # device is always self.compute_device.
                    p.data = p.data.to(self.compute_device, non_blocking=True)
                # e.g., when world_size == 1
                if not p._is_sharded:  # type: ignore[attr-defined]
                    continue
@@ -724,7 +817,6 @@ class FullyShardedDataParallel(nn.Module):
                    continue
                else:
                    # If full param has not been rebuilt or has been freed, call all gather
                    # Move params in CPU to CUDA for the all-gather.
                    p_data = p.data  # type: ignore[attr-defined]
                    p_full_size = p._full_param_padded.size()  # type: ignore[attr-defined]
                    assert (
@@ -759,7 +851,9 @@ class FullyShardedDataParallel(nn.Module):

    @torch.no_grad()
    def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
        """Free up storage for full parameters."""
        """
        Free up storage for full parameters.
        """
        if params is None:
            params = self.params
        current_stream = torch.cuda.current_stream()
@@ -780,10 +874,16 @@ class FullyShardedDataParallel(nn.Module):

    @torch.no_grad()
    def _use_param_local_shard(self, params: Optional[List[Parameter]] = None) -> None:
        """Use local shard for a list of params."""
        """Use local shard for a list of params. Also implicitly offloads
        parameters back to CPU if we are CPU offloading."""
        if params is None:
            params = self.params
        for p in params:
            if self.cpu_offload.offload_params:
                # Ensure local_shard resides in CPU if we are offloading params.
                assert (
                    p._local_shard.device == torch.device("cpu")  # type: ignore[attr-defined]
                ), "Expected p._local_shard to be on CPU"  # type: ignore[attr-defined]
            p.data = p._local_shard  # type: ignore[attr-defined]

    def _assert_state(self, state: Union[TrainingState_, List[TrainingState_]]) -> None:
@@ -1,5 +1,6 @@
# Owner(s): ["oncall: distributed"]

from contextlib import suppress
import sys
from unittest import mock

@@ -7,7 +8,9 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._fsdp import FullyShardedDataParallel
from torch.distributed._fsdp.fully_sharded_data_parallel import TrainingState_
from torch.distributed._fsdp.fully_sharded_data_parallel import (
    TrainingState_, CPUOffload
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    TEST_SKIPS,
@@ -17,7 +20,20 @@ from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
)

# get full params of a model recursively
from enum import Enum


class FSDPInitMode(Enum):
    # Move model to CUDA before wrap
    CUDA_BEFORE = 1
    # Move model to CUDA after wrap
    CUDA_AFTER = 2
    # Don't move model to CUDA at all.
    CUDA_NEVER = 3

# get full params of a model recursively. Note that if CPU offloading, it will
# also automatically move the parameters to GPU, due to _rebuild_full_params
# call.
def get_full_params(model, recurse=True):
    if recurse:
        # get all params for any nested FSDP instances.
@@ -30,6 +46,8 @@ def get_full_params(model, recurse=True):
        if model.module.flat_param is not None:
            model.module._unflatten_params()

def _maybe_cuda(model, move_to_cuda):
    return model.cuda() if move_to_cuda else model

class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
@@ -45,7 +63,8 @@ class DummyProcessGroup:

class TransformerWithSharedParams(nn.Module):
    def __init__(
        self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs
        self, group, *args, d_vocab=23, d_model=16, add_bn=True,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs
    ):
        super().__init__()
        self.rank = group.rank()
@@ -54,6 +73,7 @@ class TransformerWithSharedParams(nn.Module):
        assert (
            d_vocab >= 12
        ), "dim of vocab should be larger than 12, as we use torch.arange(12) as input"

        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
@@ -73,6 +93,8 @@ class TransformerWithSharedParams(nn.Module):

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
        move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
        self = _maybe_cuda(self, move_to_cuda)

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
@@ -99,36 +121,37 @@ class TransformerWithSharedParams(nn.Module):


class NestedWrappedModule(nn.Module):
    def __init__(self, group, wrap_fsdp, wrap_everything=False):
    def __init__(self, group, wrap_fsdp, *args, wrap_everything=False, fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FullyShardedDataParallel(layer, group)
                return FullyShardedDataParallel(layer, group, *args, **kwargs)
            return layer

        torch.manual_seed(0)  # keep everything deterministic

        if wrap_everything:
            self.module = nn.Sequential(
                _maybe_wrap(nn.Linear(8, 4)),
                _maybe_wrap(nn.Linear(4, 16)),
                _maybe_wrap(nn.Linear(16, 4)),
                _maybe_wrap(nn.Linear(4, 8)),
                _maybe_wrap(_maybe_cuda(nn.Linear(8, 4), move_to_cuda)),
                _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
                _maybe_wrap(_maybe_cuda(nn.Linear(4, 8), move_to_cuda)),
            )
        else:
            self.module = nn.Sequential(
                nn.Linear(8, 4),
                _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
                _maybe_wrap(
                    nn.Sequential(
                        _maybe_wrap(nn.Linear(4, 16)),
                        nn.Linear(16, 16),
                    )
                        _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                        _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                    ),
                ),
                _maybe_wrap(nn.Linear(16, 4)),
                nn.Linear(4, 8),
                _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
                _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
            )

    def get_input(self, device):
@@ -182,8 +205,23 @@ class ModuleWithDelay(nn.Module):


class NestedWrappedModuleWithDelay(ModuleWithDelay):
    def __init__(self, group, wrap_fsdp, **kwargs):
        super().__init__(NestedWrappedModule(group, wrap_fsdp), **kwargs)
    def __init__(
        self,
        group,
        wrap_fsdp,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        cpu_offload=None,
        **kwargs
    ):
        super().__init__(
            NestedWrappedModule(
                group,
                wrap_fsdp,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload
            ),
            **kwargs
        )


class DummyDDP(nn.Module):
@@ -196,18 +234,18 @@ class DummyDDP(nn.Module):


class MixtureOfExperts(NestedWrappedModule):
    def __init__(self, group, wrap_fsdp, delay_before_free_ms=0):
    def __init__(self, group, wrap_fsdp, *args, delay_before_free_ms=0, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, **kwargs):
        super().__init__(group, wrap_fsdp)
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp

        self.move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
        # "expert" params are different on each rank
        torch.manual_seed(42 + group.rank())
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = nn.Linear(d_expert, d_shared)
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)

        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
        for p in expert.parameters():
@@ -216,19 +254,22 @@ class MixtureOfExperts(NestedWrappedModule):
        # everything else is shared
        torch.manual_seed(0)

        shared = nn.Linear(d_shared, d_expert)
        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group(
                [group.rank()]
            )  # world size 1 means no shard
            expert = FullyShardedDataParallel(expert, expert_group)  # type: ignore[assignment]
            expert = FullyShardedDataParallel(expert, expert_group, **kwargs)  # type: ignore[assignment]

            shared = FullyShardedDataParallel(shared, group)  # type: ignore[assignment]
            shared = FullyShardedDataParallel(shared, group, **kwargs)  # type: ignore[assignment]

        self.module = nn.Sequential(
            nn.Linear(d_input, d_shared), shared, expert, nn.Linear(d_shared, d_input)
            _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
            shared,
            expert,
            _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda)
        )

    def forward(self, x):
@@ -318,7 +359,9 @@ class FSDPTest(MultiProcessTestCase):
        dist.destroy_process_group()
        sys.exit(0)

    def _train_for_several_steps(self, model, num_steps, autocast, lr=0.01):
    def _train_for_several_steps(self, model, num_steps, autocast, lr=0.01, fsdp_cpu_offload=None):
        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params

        model_device = next(model.parameters()).device
        # use SGD with momentum instead of Adam, since Adam is scale invariant
        # and this makes it bad for tests
@@ -329,19 +372,40 @@ class FSDPTest(MultiProcessTestCase):
            # Inputs always cuda regardless of cpu offloading, or model.device
            input = model.module.get_input(torch.device("cuda"))
            output = model(*input)
            # Post-forward, if CPU offloading model param should be on CPU.
            if cpu_offload_params and isinstance(model, FullyShardedDataParallel):
                for p in model.parameters():
                    # Params should always be on CPU, even if
                    # p._is_sharded=False
                    self.assertEqual(p.device, torch.device("cpu"))

            loss = model.module.get_loss(input, output).to(model_device)
            assert (
                loss.dtype == torch.float32
            ), "loss data type should be float32, as the original \
                parameter data type is float32."
            model.module.run_backward(loss)
            # Post-backward, if CPU offloading model params should be on CPU.
            if cpu_offload_params and isinstance(model, FullyShardedDataParallel):
                for p in model.parameters():
                    # Params should always be on CPU, even if
                    # p._is_sharded=False
                    self.assertEqual(p.device, torch.device("cpu"))
            optim.step()
        if isinstance(model, FullyShardedDataParallel):
            model._assert_state(TrainingState_.IDLE)
        return loss.detach()

    def _test_identical_outputs(
        self, model_init_fn, ref_ddp_fn=None, num_steps=2, use_cuda=True, lr=0.01
        self,
        model_init_fn,
        *args,
        ref_ddp_fn=None,
        num_steps=2,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        lr=0.01,
        cpu_offload=CPUOffload(),
        **kwargs
    ):
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
@@ -354,25 +418,60 @@ class FSDPTest(MultiProcessTestCase):
        else:
            model = ref_ddp_fn(model)
        ref_loss = self._train_for_several_steps(
            model, num_steps, autocast=False, lr=lr
            model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
        )
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        model = model_init_fn(group=group, wrap_fsdp=True)
        model = FullyShardedDataParallel(model)
        if use_cuda:
        try:
            model = model_init_fn(group=group, wrap_fsdp=True, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload)
        except Exception as e:
            raise ValueError(f"model_Init_fn {model_init_fn} got error {str(e)}")

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(model, cpu_offload=cpu_offload)
        # Call model.cuda() after init FSDP if specified.
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            model = model.cuda()
        else:
            assert next(model.parameters()).device == torch.device(
                "cpu"
            ), "module parameters should be placed on cpu if use_cuda is False."
        shard_loss = self._train_for_several_steps(
            model, num_steps, autocast=False, lr=lr

        # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we
        # expect FSDP code to raise error that we check below, in the case of
        # offload params.
        if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of if param is sharded.
                self.assertEqual(p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}")

        only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
        ctx = (
            self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
            if only_check_err else suppress()
        )
        with ctx:
            shard_loss = self._train_for_several_steps(
                model, num_steps, autocast=False, lr=lr,
                fsdp_cpu_offload=cpu_offload,
            )
        # We only check for errors in the case we have the following setup:
        # model = FSDP(model, cpu_offload=True)
        # model = model.cuda()
        # so skip the rest of this logic.
        if only_check_err:
            return
        # If CPU offload, next call will change model params to GPU. Sanity
        # check that params are on CPU before.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual(
                {torch.device("cpu")},
                device_set,
                f"Got device set {device_set}"
            )
        get_full_params(model)
        shard_full_params = list(model.parameters())

        if cpu_offload.offload_params:
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,