[FSDP] Only move current FSDP's states to GPU during init (#98319)
Fixes https://github.com/pytorch/pytorch/issues/95813

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98319
Approved by: https://github.com/rohan-varma
parent d7156175fe
commit 66d07e3b19
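For orientation, here is a minimal, hypothetical sketch of the wrapping pattern this change affects; the model, sizes, and process-group settings are illustrative assumptions, not taken from issue #95813:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-process, single-GPU setup purely for illustration.
dist.init_process_group(
    "nccl", init_method="tcp://127.0.0.1:29500", world_size=1, rank=0
)
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
model[0] = FSDP(model[0], device_id=torch.cuda.current_device())
# Before this commit, wrapping the root called `module.to(device)`, which
# also dragged the inner FSDP instance's already-flattened parameters to
# the GPU (with a hack to move them back under CPU offload). After it,
# init BFS-es from the root, stops at the inner FSDP boundary, and moves
# only the second Linear's still-unmanaged states.
model = FSDP(model, device_id=torch.cuda.current_device())
dist.destroy_process_group()
```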
```diff
@@ -3,6 +3,7 @@ import warnings
 from typing import (
     Any,
     Callable,
+    Deque,
     Dict,
     Generator,
     Iterable,
```
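`Deque` is imported for the annotation on the new BFS queue in `_move_module_to_device()` below (`queue: Deque[nn.Module] = collections.deque()`).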
```diff
@@ -417,8 +418,6 @@ def _init_param_handle_from_module(
         fully_sharded_module,
         check_fn=lambda k: not isinstance(k, module_wrapper_cls),
     )
-    # TODO: Investigate refactoring `_move_module_to_device()` to
-    # `_move_states_to_device()` to avoid the `device_id` + CPU offload hack
     _move_module_to_device(
         fully_sharded_module, state._ignored_params, device_from_device_id
     )
```
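The two TODO lines are removed because the next hunk performs the suggested refactor: `_move_module_to_device()` now delegates to `_move_states_to_device()` rather than moving the module wholesale.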
```diff
@@ -803,20 +802,27 @@ def _move_module_to_device(
     if param is None:
         return  # no original parameters to manage
     cpu_device = torch.device("cpu")
     # TODO: This only checks the parameter's device, not any buffers. Thus, a
     # buffer-only module will not get offloaded to CPU.
     if device_from_device_id is not None:
         if param.device == cpu_device:
-            # NOTE: This includes moving ignored modules' parameters.
-            module = module.to(device_from_device_id)
-            # TODO: This is a temporary fix to move already-constructed
-            # `FlatParameter`s back to CPU if needed. This is needed to
-            # make CPU offload work with `device_id`.
-            for submodule in module.modules():
-                if (
-                    isinstance(submodule, fsdp_file.FullyShardedDataParallel)
-                    and submodule.cpu_offload.offload_params
-                ):
-                    for handle in submodule._handles:
-                        handle.flat_param_to(torch.device("cpu"))
+            # BFS from `module` without traversing any nested FSDP instances to
+            # collect the parameters/buffers that have not yet been managed
+            queue: Deque[nn.Module] = collections.deque()
+            queue.append(module)
+            params: List[nn.Parameter] = []
+            buffers: List[torch.Tensor] = []
+            while queue:
+                curr_module = queue.popleft()
+                params.extend(curr_module.parameters(recurse=False))
+                buffers.extend(curr_module.buffers(recurse=False))
+                for submodule in curr_module.children():
+                    if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
+                        queue.append(submodule)
+            # NOTE: This includes moving ignored modules' parameters. If we
+            # decide to change the semantics in the future, simply filter based
+            # on the ignored parameters (and buffers).
+            _move_states_to_device(params, buffers, device_from_device_id)
     elif param.device == cpu_device:
         _warn_cpu_init()
```
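To see the traversal in isolation, here is a runnable sketch of the same BFS; `FSDPStub` is a hypothetical stand-in for `fsdp_file.FullyShardedDataParallel` so no process group is needed:

```python
import collections
from typing import Deque, List, Tuple

import torch
import torch.nn as nn


class FSDPStub(nn.Module):
    """Marker wrapper: the BFS must not descend into instances of this."""

    def __init__(self, wrapped: nn.Module) -> None:
        super().__init__()
        self.wrapped = wrapped


def collect_unmanaged_states(
    module: nn.Module,
) -> Tuple[List[nn.Parameter], List[torch.Tensor]]:
    # Same shape as the BFS in `_move_module_to_device()` above.
    queue: Deque[nn.Module] = collections.deque()
    queue.append(module)
    params: List[nn.Parameter] = []
    buffers: List[torch.Tensor] = []
    while queue:
        curr_module = queue.popleft()
        # Take only this module's own states; children are queued separately.
        params.extend(curr_module.parameters(recurse=False))
        buffers.extend(curr_module.buffers(recurse=False))
        for submodule in curr_module.children():
            # Skip wrapped subtrees: a nested FSDP instance already
            # manages (and will move) those states itself.
            if not isinstance(submodule, FSDPStub):
                queue.append(submodule)
    return params, buffers


root = nn.Sequential(nn.Linear(4, 4), FSDPStub(nn.Linear(4, 4)))
params, buffers = collect_unmanaged_states(root)
assert len(params) == 2  # only the unwrapped Linear's weight and bias
```

Checking `children()` against the wrapper type before enqueueing, combined with `recurse=False`, is what guarantees the traversal enumerates exactly the states no FSDP instance has claimed yet.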
```diff
@@ -827,7 +833,8 @@ def _move_states_to_device(
     device_from_device_id: Optional[torch.device],
 ) -> None:
     """
-    Precondition: ``_check_single_device_module()``.
+    Precondition: ``_check_single_device_module()`` and module's parameters and
+    buffers have been materialized if needed.
     """
     if len(params) == 0 and len(buffers) == 0:
         return
```
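The expanded precondition makes the ordering explicit: by the time `_move_states_to_device()` runs, any deferred initialization (presumably meta-device or similar materialization) must already have produced real tensors, since the function now receives concrete parameter and buffer lists rather than a whole module to move.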