[FSDP] Only move current FSDP's states to GPU during init (#98319)

Fixes https://github.com/pytorch/pytorch/issues/95813
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98319
Approved by: https://github.com/rohan-varma
Andrew Gu 2023-04-04 16:54:50 +00:00 committed by PyTorch MergeBot
parent d7156175fe
commit 66d07e3b19
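
For context, the setup this change targets looks roughly like the sketch below: inner submodules are wrapped first, so their parameters are already flattened (and possibly CPU-offloaded) by the time the root wrap runs. This is a minimal, hypothetical sketch assuming a single-rank NCCL process group and one CUDA device; the module shapes and rendezvous address are illustrative, not taken from the linked issue. With this commit, the root FSDP instance moves only its own not-yet-managed states to the GPU instead of calling module.to() on the whole tree.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

def main() -> None:
    # Single-rank process group so the sketch runs on one GPU.
    dist.init_process_group(
        "nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    torch.cuda.set_device(0)
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    # Wrap an inner module first; its parameters become a FlatParameter managed
    # (and CPU-offloaded) by this nested FSDP instance.
    model[0] = FSDP(
        model[0], device_id=0, cpu_offload=CPUOffload(offload_params=True)
    )
    # With this change, the root FSDP instance only moves its own not-yet-managed
    # states (model[1] here) during init, leaving the nested instance's
    # already-managed, possibly CPU-offloaded flat parameter alone.
    model = FSDP(
        model, device_id=0, cpu_offload=CPUOffload(offload_params=True)
    )
    dist.destroy_process_group()

if __name__ == "__main__":
    main()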


@@ -3,6 +3,7 @@ import warnings
from typing import (
    Any,
    Callable,
    Deque,
    Dict,
    Generator,
    Iterable,
@@ -417,8 +418,6 @@ def _init_param_handle_from_module(
            fully_sharded_module,
            check_fn=lambda k: not isinstance(k, module_wrapper_cls),
        )
    # TODO: Investigate refactoring `_move_module_to_device()` to
    # `_move_states_to_device()` to avoid the `device_id` + CPU offload hack
    _move_module_to_device(
        fully_sharded_module, state._ignored_params, device_from_device_id
    )
@@ -803,20 +802,27 @@ def _move_module_to_device(
    if param is None:
        return  # no original parameters to manage
    cpu_device = torch.device("cpu")
    # TODO: This only checks the parameter's device, not any buffers. Thus, a
    # buffer-only module will not get offloaded to CPU.
    if device_from_device_id is not None:
        if param.device == cpu_device:
            # NOTE: This includes moving ignored modules' parameters.
            module = module.to(device_from_device_id)
            # TODO: This is a temporary fix to move already-constructed
            # `FlatParameter`s back to CPU if needed. This is needed to
            # make CPU offload work with `device_id`.
            for submodule in module.modules():
                if (
                    isinstance(submodule, fsdp_file.FullyShardedDataParallel)
                    and submodule.cpu_offload.offload_params
                ):
                    for handle in submodule._handles:
                        handle.flat_param_to(torch.device("cpu"))
        # BFS from `module` without traversing any nested FSDP instances to
        # collect the parameters/buffers that have not yet been managed
        queue: Deque[nn.Module] = collections.deque()
        queue.append(module)
        params: List[nn.Parameter] = []
        buffers: List[torch.Tensor] = []
        while queue:
            curr_module = queue.popleft()
            params.extend(curr_module.parameters(recurse=False))
            buffers.extend(curr_module.buffers(recurse=False))
            for submodule in curr_module.children():
                if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
                    queue.append(submodule)
        # NOTE: This includes moving ignored modules' parameters. If we
        # decide to change the semantics in the future, simply filter based
        # on the ignored parameters (and buffers).
        _move_states_to_device(params, buffers, device_from_device_id)
    elif param.device == cpu_device:
        _warn_cpu_init()
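
Read in isolation, the BFS added above amounts to the following self-contained sketch. The stop_at predicate is a hypothetical stand-in for the isinstance(submodule, fsdp_file.FullyShardedDataParallel) check so the example runs without a process group; the real code then hands the collected lists to _move_states_to_device().

import collections
from typing import Callable, Deque, List, Tuple

import torch
import torch.nn as nn

def collect_unmanaged_states(
    root: nn.Module,
    stop_at: Callable[[nn.Module], bool],
) -> Tuple[List[nn.Parameter], List[torch.Tensor]]:
    # Breadth-first walk from `root`, skipping subtrees whose roots are
    # already managed (nested FSDP instances in the real code).
    queue: Deque[nn.Module] = collections.deque([root])
    params: List[nn.Parameter] = []
    buffers: List[torch.Tensor] = []
    while queue:
        curr = queue.popleft()
        params.extend(curr.parameters(recurse=False))
        buffers.extend(curr.buffers(recurse=False))
        for child in curr.children():
            if not stop_at(child):
                queue.append(child)
    return params, buffers

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Linear(4, 4))
# Pretend the BatchNorm child is already managed by a nested wrapper.
params, buffers = collect_unmanaged_states(
    model, stop_at=lambda m: isinstance(m, nn.BatchNorm1d)
)
assert len(params) == 4 and len(buffers) == 0  # only the two Linear layers' states

Collecting states per FSDP instance, rather than moving the whole module, is what keeps each instance from re-moving parameters that a nested instance already manages or has offloaded to CPU.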
@@ -827,7 +833,8 @@ def _move_states_to_device(
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Precondition: ``_check_single_device_module()``.
    Precondition: ``_check_single_device_module()`` and module's parameters and
    buffers have been materialized if needed.
    """
    if len(params) == 0 and len(buffers) == 0:
        return
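
For reference, moving the collected states in place could look like the sketch below. This is only a guess at the general shape of such a helper, not the actual _move_states_to_device() implementation; the in-place .data assignment is an assumption made so that existing references to the Parameter and buffer objects remain valid after the move.

from typing import List, Optional

import torch
import torch.nn as nn

def move_states_to_device_sketch(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    device: Optional[torch.device],
) -> None:
    # Hypothetical helper: move each tensor's storage in place so callers that
    # already hold references to these Parameters/buffers see the new device.
    if device is None or (len(params) == 0 and len(buffers) == 0):
        return
    with torch.no_grad():
        for param in params:
            param.data = param.data.to(device)
            if param.grad is not None:
                param.grad.data = param.grad.data.to(device)
        for buffer in buffers:
            buffer.data = buffer.data.to(device)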