**Summary**
This PR extends `_offload_state_dict_to_cpu` to accept a `cpu_offload_state_dict` argument. If `cpu_offload_state_dict` is not None, `_offload_state_dict_to_cpu` will use `copy_` to copy the GPU data into the CPU tensors. This allows users to pass a pin_memory or share_memory version of `cpu_offload_state_dict`. This PR also adds `_create_cpu_state_dict` to let users easily create a pin_memory or share_memory CPU state_dict.

**Performance improvement**
```
# The micro-benchmark has a source state_dict with 150 tensors, and each tensor is 50MB.
# The micro-benchmark is run on an H100 machine with PCIe 5.

cpu_state_dict_2 = _create_cpu_state_dict(state_dict, pin_memory=True)
cpu_state_dict_3 = _create_cpu_state_dict(state_dict, share_memory=True)

# GPU -> CPU memory: 4.6556 seconds
cpu_state_dict = _offload_state_dict_to_cpu(state_dict)

# GPU -> pin memory: 0.1566 seconds
_offload_state_dict_to_cpu(state_dict, cpu_offload_state_dict=cpu_state_dict_2)

# GPU -> shared memory: 0.5509 seconds (variation is quite large)
_offload_state_dict_to_cpu(state_dict, cpu_offload_state_dict=cpu_state_dict_3)

# GPU -> pin memory -> shared memory: 0.2550 seconds
_offload_state_dict_to_cpu(state_dict, cpu_offload_state_dict=cpu_state_dict_2)
_offload_state_dict_to_cpu(cpu_state_dict_2, cpu_offload_state_dict=cpu_state_dict_3)
```

Differential Revision: [D54045845](https://our.internmc.facebook.com/intern/diff/D54045845/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120378
Approved by: https://github.com/LucasLLC
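A minimal per-iteration sketch of the intended flow (illustrative only; `model`, `train_step`, and `num_steps` are hypothetical, and tensor shapes are assumed stable across steps so the buffers can be reused):

```
# Allocate reusable pinned CPU buffers once, outside the training loop.
cpu_buffers = _create_cpu_state_dict(model.state_dict(), pin_memory=True)

for _ in range(num_steps):
    train_step(model)  # hypothetical training step
    # Non-blocking GPU -> pinned-CPU copies; synchronizes before returning.
    _offload_state_dict_to_cpu(
        model.state_dict(),
        cpu_offload_state_dict=cpu_buffers,
        cpu_offload_sync=True,
    )
```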
import io
import math
from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor


if dist.is_available() or TYPE_CHECKING:
    from torch.distributed import distributed_c10d
    from torch.distributed._shard.sharded_tensor import ShardedTensor
    from torch.distributed._tensor import DTensor, Replicate

def _identity_func(
    obj: torch.Tensor,
    pg: Optional[dist.ProcessGroup],
    device: Optional[torch.device],
    companion_obj: Any,
) -> torch.Tensor:
    return obj


def _all_gather_sharded_tensor(
    sharded_tensor: "ShardedTensor",
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    pg_device = (
        distributed_c10d._get_pg_default_device(pg) if device is None else device
    )
    if shards:
        local_tensor = shards[0].tensor.flatten()
        if local_tensor.device.type != pg_device.type:
            local_tensor = local_tensor.to(pg_device)
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        local_tensor = torch.zeros(
            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
        )

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=pg_device,
    )
    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)

    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor
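# Worked example of the chunking math above (editor's illustration, not in
# the original file): for a ShardedTensor of global size (10, 4) gathered
# across world_size=4, dim_0_size=10 and tensor_numel=40, so
# chunk_size = ceil(10 / 4) * 40 // 10 = 3 * 4 = 12 elements per rank.
# A rank holding only one row (4 elements) is zero-padded to 12, the
# all_gather produces a 48-element buffer, and narrow(0, 0, 40) plus
# reshape recovers the exact (10, 4) tensor.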
class CompanionMismatch(Exception):
    ...


def _iterate_state_dict(
    iter_object: Any,
    sharded_tensor_func: Callable,
    dtensor_func: Callable,
    tensor_func: Callable,
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    companion_obj: Any = None,
    ranks_only: Tuple[int, ...] = tuple(),
    type_check: bool = True,
) -> Dict[str, Any]:
    # TODO: should we use pytree?
    cpu_device = torch.device("cpu")
    if isinstance(iter_object, ShardedTensor):
        ret = sharded_tensor_func(iter_object, pg, device, companion_obj)
    elif isinstance(iter_object, DTensor):
        ret = dtensor_func(iter_object, pg, device, companion_obj)
    elif isinstance(iter_object, torch.Tensor):
        ret = tensor_func(iter_object, pg, device, companion_obj)
    elif (
        isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
        or iter_object is None
    ):
        ret = iter_object
    elif isinstance(iter_object, dict):
        if companion_obj is not None and (
            not isinstance(companion_obj, dict)
            or set(companion_obj.keys()) != set(iter_object.keys())
        ):
            raise CompanionMismatch()

        ret = {
            key: _iterate_state_dict(
                value,
                sharded_tensor_func,
                dtensor_func,
                tensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                companion_obj=companion_obj[key] if companion_obj is not None else None,
                ranks_only=ranks_only,
                type_check=type_check,
            )
            for key, value in iter_object.items()
        }
    elif isinstance(iter_object, (list, tuple)):
        if companion_obj is not None and (
            not isinstance(companion_obj, (list, tuple))
            or len(companion_obj) != len(iter_object)
        ):
            raise CompanionMismatch()

        ret = [
            _iterate_state_dict(
                v,
                sharded_tensor_func,
                dtensor_func,
                tensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                companion_obj=companion_obj[idx] if companion_obj is not None else None,
                ranks_only=ranks_only,
                type_check=type_check,
            )
            for idx, v in enumerate(iter_object)
        ]
        if isinstance(iter_object, tuple):
            ret = tuple(ret)
    elif not type_check:
        ret = iter_object
    else:
        raise ValueError(f"Unexpected value type {type(iter_object)}")

    if not ranks_only or dist.get_rank(pg) in ranks_only:
        if isinstance(ret, torch.Tensor) and cpu_offload:
            if companion_obj is None:
                ret = ret.to(cpu_device)
            else:
                # TODO: support DTensor
                companion_obj.copy_(ret, non_blocking=True)
    else:
        ret = {} if isinstance(ret, dict) else None

    return ret
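# Editor's illustration (not in the original file): _iterate_state_dict walks
# the nested containers and applies the matching per-type callback to each
# leaf, passing the aligned companion_obj node along. For example:
#
#     doubled = _iterate_state_dict(
#         {"w": torch.ones(2), "step": 3},
#         _identity_func,  # ShardedTensor handler
#         _identity_func,  # DTensor handler
#         lambda t, pg, device, companion: t * 2,  # plain-tensor handler
#     )
#     # doubled == {"w": tensor([2., 2.]), "step": 3}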
def _gather_state_dict(
    state_dict: Dict[str, Any],
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    ranks_only: Tuple[int, ...] = tuple(),
    type_check: bool = True,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API gathers all the ShardedTensors or DTensors in
    the state_dict.

    Args:
        state_dict (Dict[str, Any]): the target sharded state_dict.
        pg (Optional[dist.ProcessGroup]): the process group that is used to
            gather ShardedTensor. Note that gathering a DTensor will use
            the DeviceMesh, so this argument will be ignored when gathering a
            DTensor.
        device: (Optional[torch.device]): the device that is used to
            perform allgather for ShardedTensor. Note that gathering a DTensor
            will use the DeviceMesh, so this argument will be ignored when
            gathering a DTensor.
        cpu_offload (bool): whether to offload the tensors to CPU memory. The
            default value is False.
        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only the ranks in
            ``ranks_only`` will get the full state_dicts; other ranks will
            get empty state_dicts.
        type_check: (bool): check if the instance data type is a supported type
            that can be saved by DCP. The current supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        The gathered state dictionary.
    """

    def sharded_tensor_func(value, pg, device, companion_obj):
        # ShardedTensor does not seem to record the original device type.
        # So if the tensor is moved to CPU, we won't know the original type.
        # As a result, we have to rely on the user to tell us the correct one.
        cpu_device = torch.device("cpu")
        output_tensor = _all_gather_sharded_tensor(value, pg, device)
        local_shard_device = (
            value.local_shards()[0].tensor.device
            if value.local_shards()
            else cpu_device
        )
        if output_tensor.device != local_shard_device:
            value = output_tensor.to(local_shard_device)
        else:
            value = output_tensor
        return value

    def dtensor_func(value, pg, device, companion_obj):
        if value.device != value.device_mesh.device_type:
            value = value.to(value.device_mesh.device_type)
        # FSDP all_gather: [Shard(0)] -> [Replicate()]
        # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
        # 2D FSDP + TP all_gather:
        # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
        # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
        placements = [Replicate() for _ in value.placements]
        value = value.redistribute(
            device_mesh=value.device_mesh,
            placements=placements,
        )
        # Call `wait()` to force the tensor to be synchronous with respect
        # to the main stream.
        # See the discussion in https://github.com/pytorch/pytorch/pull/117799.
        value = value.to_local()
        if isinstance(value, AsyncCollectiveTensor):
            value = value.wait()
        return value

    return _iterate_state_dict(
        state_dict,
        sharded_tensor_func,
        dtensor_func,
        _identity_func,
        pg=pg,
        device=device,
        cpu_offload=cpu_offload,
        ranks_only=ranks_only,
        type_check=type_check,
    )
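# Editor's illustration (not in the original file): gather a sharded model
# state_dict onto rank 0 only and save it there; `fsdp_model` is a
# hypothetical FSDP-wrapped module.
#
#     state_dict = _gather_state_dict(
#         fsdp_model.state_dict(), cpu_offload=True, ranks_only=(0,)
#     )
#     if dist.get_rank() == 0:
#         torch.save(state_dict, "full_checkpoint.pt")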
def _offload_state_dict_to_cpu(
    state_dict: Dict[str, Any],
    *,
    ranks_only: Tuple[int, ...] = tuple(),
    cpu_offload_state_dict: Optional[Dict[str, Any]] = None,
    cpu_offload_sync: bool = True,
    type_check: bool = True,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API offloads all the tensors to CPU memory.

    Args:
        state_dict (Dict[str, Any]): the target state_dict.
        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only the ranks in
            ``ranks_only`` will get the full state_dicts; other ranks will
            get empty state_dicts.
        cpu_offload_state_dict (Optional[Dict[str, Any]]): the CPU state_dict
            that will be returned. If this is not None, this API will use
            `copy_` to copy the GPU tensor to the tensor in this CPU state_dict.
            This CPU state_dict must have exactly the same structure as
            `state_dict`; the only difference is that all the tensors in this
            CPU state_dict are in CPU memory.
        cpu_offload_sync: (bool): flag to decide whether to call `synchronize()`
            before this API returns.
        type_check: (bool): check if the instance data type is a supported type
            that can be saved by DCP. The current supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        The offloaded state dictionary.
    """

    ret = _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        _identity_func,
        pg=None,
        device=None,
        cpu_offload=True,
        ranks_only=ranks_only,
        companion_obj=cpu_offload_state_dict,
        type_check=type_check,
    )
    if cpu_offload_state_dict is not None and cpu_offload_sync:
        torch.cuda.synchronize()
    return ret
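# Editor's illustration (not in the original file): reusing a pinned CPU
# state_dict turns the copies into non-blocking DMA transfers instead of
# fresh pageable-memory allocations; `torch.cuda.synchronize()` runs before
# the call returns because cpu_offload_sync=True. `gpu_state_dict` is a
# hypothetical state_dict of CUDA tensors.
#
#     pinned_sd = _create_cpu_state_dict(gpu_state_dict, pin_memory=True)
#     _offload_state_dict_to_cpu(
#         gpu_state_dict, cpu_offload_state_dict=pinned_sd, cpu_offload_sync=True
#     )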
def _create_cpu_state_dict(
    state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
) -> Dict[str, Any]:
    """
    Given a state_dict, create another state_dict with the same structure and elements.
    However, all tensors in the returned state_dict are new tensors on CPU. These
    tensors can be placed on pin_memory or share_memory based on the provided arguments.
    """

    if pin_memory and share_memory:
        raise ValueError(
            "Cannot allocate memory on both pin_memory and share_memory"
        )

    def tensor_func(
        obj: torch.Tensor,
        pg: Optional[dist.ProcessGroup],
        device: Optional[torch.device],
        companion_obj: Any,
    ) -> torch.Tensor:
        if len(obj.size()) == 0:
            return torch.tensor(0, dtype=obj.dtype)

        if share_memory:
            return torch.empty(
                *tuple(companion_obj.size()), dtype=companion_obj.dtype
            ).share_memory_()
        ret = torch.empty(*tuple(companion_obj.size()), dtype=companion_obj.dtype)
        # Only pin the memory when explicitly requested; otherwise return a
        # plain CPU tensor.
        return ret.pin_memory() if pin_memory else ret

    ret = _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        tensor_func,
        pg=None,
        device=None,
        cpu_offload=False,
        ranks_only=tuple(),
        companion_obj=state_dict,
        type_check=False,
    )
    return ret
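# Editor's illustration (not in the original file): allocate the two buffer
# flavors once and reuse them; `gpu_state_dict` is a hypothetical state_dict
# of CUDA tensors.
#
#     pinned_sd = _create_cpu_state_dict(gpu_state_dict, pin_memory=True)
#     shared_sd = _create_cpu_state_dict(gpu_state_dict, share_memory=True)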
def _check_state_dict_similarity(
    state_dict: Dict[str, Any],
    compared_state_dict: Dict[str, Any],
) -> bool:
    """
    Given two state_dicts, check if the structures are the same. If a
    [key, tensor] pair exists in one state_dict, there must be a corresponding
    [key, other_tensor] pair in the other state_dict, where tensor and
    other_tensor have the same size and dtype.

    Return the check result.
    """

    def tensor_func(
        obj: torch.Tensor,
        pg: Optional[dist.ProcessGroup],
        device: Optional[torch.device],
        companion_obj: Any,
    ) -> torch.Tensor:
        if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
            raise CompanionMismatch()
        return obj

    try:
        _iterate_state_dict(
            state_dict,
            _identity_func,
            _identity_func,
            tensor_func,
            pg=None,
            device=None,
            cpu_offload=False,
            ranks_only=tuple(),
            companion_obj=compared_state_dict,
            type_check=False,
        )
    except CompanionMismatch:
        return False

    return True
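# Editor's illustration (not in the original file): the check fails on any
# structure, size, or dtype mismatch.
#
#     a = {"w": torch.zeros(2, 2)}
#     b = {"w": torch.zeros(2, 2, dtype=torch.float16)}
#     _check_state_dict_similarity(a, a)  # True
#     _check_state_dict_similarity(a, b)  # False: dtype differs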