pytorch/torch/distributed/_state_dict_utils.py
Chien-Chin Huang 5abf7972d1 [DCP][state_dict] Implement pin_memory and shared_memory copy for _offload_state_dict_to_cpu (#120378)
**Summary**
This PR extends `_offload_state_dict_to_cpu` to accept a `cpu_offload_state_dict` argument. If `cpu_offload_state_dict` is not None, `_offload_state_dict_to_cpu` will use `copy_` to copy the GPU data into the CPU tensors. This allows users to pass a pin_memory or share_memory version of `cpu_offload_state_dict`.

This PR also adds `_create_cpu_state_dict` to allow users to easily create a pin_memory or share_memory cpu state_dict.

**Performance improvement**
```
# The micro-benchmark has a source state_dict with 150 tensors, and each tensor is 50MB.
# The micro-benchmark is run on an H100 machine with PCIe 5

cpu_state_dict_2 = _create_cpu_state_dict(state_dict, pin_memory=True)
cpu_state_dict_3 = _create_cpu_state_dict(state_dict, share_memory=True)

# GPU->CPU memory: 4.6556 seconds
cpu_state_dict = _offload_state_dict_to_cpu(state_dict)

# GPU->pin memory: 0.1566 seconds
_offload_state_dict_to_cpu(state_dict, cpu_offload_state_dict=cpu_state_dict_2)

# GPU->shared memory: 0.5509 seconds (variation is quite large)
_offload_state_dict_to_cpu(state_dict, cpu_offload_state_dict=cpu_state_dict_3)

# GPU->pin memory->shared memory: 0.2550 seconds
_offload_state_dict_to_cpu(state_dict, cpu_offload_state_dict=cpu_state_dict_2)
_offload_state_dict_to_cpu(cpu_state_dict_2, cpu_offload_state_dict=cpu_state_dict_3)
```

Differential Revision: [D54045845](https://our.internmc.facebook.com/intern/diff/D54045845/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120378
Approved by: https://github.com/LucasLLC
2024-03-05 17:48:15 +00:00

import io
import math
from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor

if dist.is_available() or TYPE_CHECKING:
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, Replicate


def _identity_func(
obj: torch.Tensor,
pg: Optional[dist.ProcessGroup],
device: Optional[torch.device],
companion_obj: Any,
) -> torch.Tensor:
return obj


def _all_gather_sharded_tensor(
sharded_tensor: "ShardedTensor",
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
if pg is None:
pg = distributed_c10d._get_default_group()
world_size = dist.get_world_size(pg)
shards = sharded_tensor.local_shards()
dim_0_size = sharded_tensor.size()[0] # type: ignore[index]
tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr]
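    # A hypothetical worked example (not from the source): for a (10, 6)
    # sharded tensor gathered across 4 ranks, chunk_size below is
    # math.ceil(10 / 4) * 60 // 10 = 3 * 6 = 18 elements, so each rank
    # contributes a flattened chunk of 18 elements, zero-padded if its
    # local shard is shorter.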
chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
pg_device = (
distributed_c10d._get_pg_default_device(pg) if device is None else device
)
if shards:
local_tensor = shards[0].tensor.flatten()
if local_tensor.device.type != pg_device.type:
local_tensor = local_tensor.to(pg_device)
num_padding = chunk_size - local_tensor.numel()
if num_padding > 0:
local_tensor = F.pad(local_tensor, [0, num_padding])
else:
local_tensor = torch.zeros(
chunk_size, dtype=sharded_tensor.dtype, device=pg_device
)
tensor = torch.empty(
chunk_size * world_size,
dtype=local_tensor.dtype,
device=pg_device,
)
dist.all_gather_into_tensor(tensor, local_tensor, group=pg)
tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
return tensor


class CompanionMismatch(Exception):
...


def _iterate_state_dict(
iter_object: Any,
sharded_tensor_func: Callable,
dtensor_func: Callable,
tensor_func: Callable,
*,
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
cpu_offload: bool = False,
companion_obj: Any = None,
ranks_only: Tuple[int, ...] = tuple(),
type_check: bool = True,
) -> Dict[str, Any]:
# TODO: should we use pytree?
cpu_device = torch.device("cpu")
if isinstance(iter_object, ShardedTensor):
ret = sharded_tensor_func(iter_object, pg, device, companion_obj)
elif isinstance(iter_object, DTensor):
ret = dtensor_func(iter_object, pg, device, companion_obj)
elif isinstance(iter_object, torch.Tensor):
ret = tensor_func(iter_object, pg, device, companion_obj)
elif (
isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
or iter_object is None
):
ret = iter_object
elif isinstance(iter_object, dict):
if companion_obj is not None and (
not isinstance(companion_obj, dict)
or set(companion_obj.keys()) != set(iter_object.keys())
):
raise CompanionMismatch()
ret = {
key: _iterate_state_dict(
value,
sharded_tensor_func,
dtensor_func,
tensor_func,
pg=pg,
device=device,
cpu_offload=cpu_offload,
companion_obj=companion_obj[key] if companion_obj is not None else None,
ranks_only=ranks_only,
type_check=type_check,
)
for key, value in iter_object.items()
}
elif isinstance(iter_object, (list, tuple)):
if companion_obj is not None and (
not isinstance(companion_obj, (list, tuple))
or len(companion_obj) != len(iter_object)
):
raise CompanionMismatch()
ret = [
_iterate_state_dict(
v,
sharded_tensor_func,
dtensor_func,
tensor_func,
pg=pg,
device=device,
cpu_offload=cpu_offload,
companion_obj=companion_obj[idx] if companion_obj is not None else None,
ranks_only=ranks_only,
type_check=type_check,
)
for idx, v in enumerate(iter_object)
]
if isinstance(iter_object, tuple):
ret = tuple(ret)
elif not type_check:
ret = iter_object
else:
raise ValueError(f"Unexpected value type {type(iter_object)}")
if not ranks_only or dist.get_rank(pg) in ranks_only:
if isinstance(ret, torch.Tensor) and cpu_offload:
if companion_obj is None:
ret = ret.to(cpu_device)
else:
# TODO: support DTensor
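                # `non_blocking=True` overlaps the device-to-host copy with
                # compute only when `companion_obj` is in pinned memory;
                # callers must synchronize (as `_offload_state_dict_to_cpu`
                # does) before reading the destination tensors.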
companion_obj.copy_(ret, non_blocking=True)
else:
ret = {} if isinstance(ret, dict) else None
return ret
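

# A hedged illustration (not part of the module): `companion_obj` must mirror
# the structure of `iter_object` exactly. For example, with
# iter_object = {"w": gpu_tensor} and companion_obj = {"w": cpu_tensor},
# the traversal copies gpu_tensor into cpu_tensor in place; a missing or
# extra key raises CompanionMismatch.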


def _gather_state_dict(
state_dict: Dict[str, Any],
*,
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
cpu_offload: bool = False,
ranks_only: Tuple[int, ...] = tuple(),
type_check: bool = True,
) -> Dict[str, Any]:
"""
Given a state_dict, this API gathers all the ShardedTensors or DTensors in
the state_dict.

    Args:
state_dict (Dict[str, Any]): the target sharded state_dict.
pg (Optional[dist.ProcessGroup]): the process group that is used to
gather ShardedTensor. Note that gathering a DTensor will use
the DeviceMesh. So this argument will be ignored when gathering a
DTensor.
        device (Optional[torch.device]): the device that is used to
perform allgather for ShardedTensor. Note that gathering a DTensor
will use the DeviceMesh. So this argument will be ignored when
gathering a DTensor.
cpu_offload (bool): whether to offload the tensors to CPU memory. The
default value is False.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only the ranks in
            ``ranks_only`` have the same state_dicts. Other ranks will get
            empty state_dicts.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The current supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

Returns:
The gathered state dictionary.
"""

    def sharded_tensor_func(value, pg, device, companion_obj):
# ShardedTensor does not seem to record the original device type.
# So if the tensor is moved to CPU, we won't know the original type.
# As a result, we have to rely on the user to tell us the correct one.
cpu_device = torch.device("cpu")
output_tensor = _all_gather_sharded_tensor(value, pg, device)
local_shard_device = (
value.local_shards()[0].tensor.device
if value.local_shards()
else cpu_device
)
if output_tensor.device != local_shard_device:
value = output_tensor.to(local_shard_device)
else:
value = output_tensor
return value

    def dtensor_func(value, pg, device, companion_obj):
        if value.device.type != value.device_mesh.device_type:
value = value.to(value.device_mesh.device_type)
# FSDP all_gather: [Shard(0)] -> [Replicate()]
# HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
# 2D FSDP + TP all_gather:
# - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
# - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
placements = [Replicate() for _ in value.placements]
value = value.redistribute(
device_mesh=value.device_mesh,
placements=placements,
)
# Call `wait()` to force the tensor to be synchronous with respect
# to the main stream.
# See the discussion in https://github.com/pytorch/pytorch/pull/117799.
value = value.to_local()
if isinstance(value, AsyncCollectiveTensor):
value = value.wait()
return value

    return _iterate_state_dict(
state_dict,
sharded_tensor_func,
dtensor_func,
_identity_func,
pg=pg,
device=device,
cpu_offload=cpu_offload,
ranks_only=ranks_only,
type_check=type_check,
)
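

# Example usage (an illustrative sketch, not from the source): gather a
# sharded state_dict into full tensors on every rank and move them to CPU
# in one pass:
#
#     full_sd = _gather_state_dict(sharded_sd, cpu_offload=True)
#
# Passing ranks_only=(0,) would materialize the gathered state_dict on rank 0
# only; all other ranks would receive empty state_dicts.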


def _offload_state_dict_to_cpu(
state_dict: Dict[str, Any],
*,
ranks_only: Tuple[int, ...] = tuple(),
cpu_offload_state_dict: Optional[Dict[str, Any]] = None,
cpu_offload_sync: bool = True,
type_check: bool = True,
) -> Dict[str, Any]:
"""
    Given a state_dict, this API offloads all the tensors to CPU memory.

    Args:
        state_dict (Dict[str, Any]): the target state_dict.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only the ranks in
            ``ranks_only`` have the same state_dicts. Other ranks will get
            empty state_dicts.
        cpu_offload_state_dict (Optional[Dict[str, Any]]): the CPU state_dict
            that will be returned. If this is not None, this API will use
            `copy_` to copy the GPU tensors to the tensors in this CPU
            state_dict. This CPU state_dict must have exactly the same
            structure as `state_dict`; the only difference is that all the
            tensors in this CPU state_dict are in CPU memory.
        cpu_offload_sync (bool): flag to decide whether to call `synchronize()`
            before this API returns.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The current supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        The offloaded state dictionary.
"""
ret = _iterate_state_dict(
state_dict,
_identity_func,
_identity_func,
_identity_func,
pg=None,
device=None,
cpu_offload=True,
ranks_only=ranks_only,
companion_obj=cpu_offload_state_dict,
type_check=type_check,
)
if cpu_offload_state_dict is not None and cpu_offload_sync:
torch.cuda.synchronize()
return ret
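

# Example usage (a sketch mirroring the benchmark in the PR summary): allocate
# the pinned-memory destination once, then reuse it so every offload is a fast
# non-blocking device-to-host copy instead of a fresh allocation:
#
#     pinned_sd = _create_cpu_state_dict(gpu_sd, pin_memory=True)
#     _offload_state_dict_to_cpu(gpu_sd, cpu_offload_state_dict=pinned_sd)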


def _create_cpu_state_dict(
state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
) -> Dict[str, Any]:
"""
Given a state_dict, create another state_dict with the same structure and elements.
    However, all tensors in the returned state_dict are new tensors on CPU. These
    tensors can be allocated in pinned memory or shared memory based on the
    provided arguments.
"""
if pin_memory and share_memory:
raise ValueError(
"Cannot allocate both memory on both pin_memory and share_memory"
)

    def tensor_func(
obj: torch.Tensor,
pg: Optional[dist.ProcessGroup],
device: Optional[torch.device],
companion_obj: Any,
) -> torch.Tensor:
if len(obj.size()) == 0:
return torch.tensor(0, dtype=obj.dtype)
if share_memory:
return torch.empty(
*tuple(companion_obj.size()), dtype=companion_obj.dtype
).share_memory_()
else:
return torch.empty(
*tuple(companion_obj.size()), dtype=companion_obj.dtype
).pin_memory()

    ret = _iterate_state_dict(
state_dict,
_identity_func,
_identity_func,
tensor_func,
pg=None,
device=None,
cpu_offload=False,
ranks_only=tuple(),
companion_obj=state_dict,
type_check=False,
)
return ret
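

# Example usage (a sketch, assuming a separate checkpoint-writer process): a
# share_memory state_dict can be handed to another process via
# torch.multiprocessing, which shares such tensors without an extra copy:
#
#     shm_sd = _create_cpu_state_dict(gpu_sd, share_memory=True)
#     _offload_state_dict_to_cpu(gpu_sd, cpu_offload_state_dict=shm_sd)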


def _check_state_dict_similarity(
state_dict: Dict[str, Any],
compared_state_dict: Dict[str, Any],
) -> bool:
"""
    Given two state_dicts, check if the structures are the same. If a
    [key, tensor] pair exists in one state_dict, there must be a corresponding
    pair, [key, other_tensor], in the other state_dict, where tensor and
    other_tensor have the same size and dtype.

    Return the check result.
"""

    def tensor_func(
obj: torch.Tensor,
pg: Optional[dist.ProcessGroup],
device: Optional[torch.device],
companion_obj: Any,
) -> torch.Tensor:
if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
raise CompanionMismatch()
return obj

    try:
_iterate_state_dict(
state_dict,
_identity_func,
_identity_func,
tensor_func,
pg=None,
device=None,
cpu_offload=False,
ranks_only=tuple(),
companion_obj=compared_state_dict,
type_check=False,
)
except CompanionMismatch:
return False
return True
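

# Example usage (an illustrative sketch): validate a cached CPU state_dict
# before reusing it as the `cpu_offload_state_dict` destination, rebuilding it
# if the model structure changed:
#
#     if not _check_state_dict_similarity(gpu_sd, cached_cpu_sd):
#         cached_cpu_sd = _create_cpu_state_dict(gpu_sd, pin_memory=True)
#     _offload_state_dict_to_cpu(gpu_sd, cpu_offload_state_dict=cached_cpu_sd)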