pytorch/torch/distributed/checkpoint/planner_helpers.py
Saurabh Mishra d4c578082a [DCP] Cache save plan metadata to reduce the collective overhead (#149785)
Summary:
Cache save plan metadata to reduce the collective overhead.

Global plan dedupe and metadata creation are the main overheads on Rank 0. This change saves all this cost for the subsequent saves if the plans do not change. A quick experiment with the 256 rank job, Global step overhead drops by ~99%, from 90s+ to mere 1.5s. 1.5s was mostly spent on creating the checkpoint module directories and near empty collective.

Differential Revision: D71631441

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149785
Approved by: https://github.com/MeetVadakkanchery
2025-03-25 02:00:15 +00:00

491 lines
16 KiB
Python

# mypy: allow-untyped-defs
import io
from typing import Any, Callable, cast
import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from .metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
MetadataIndex,
STATE_DICT_TYPE,
STORAGE_TYPES,
TensorProperties,
TensorStorageMetadata,
)
from .planner import (
LoadItemType,
ReadItem,
SavePlan,
TensorWriteData,
WriteItem,
WriteItemType,
)
from .resharding import (
_check_shard_metadata_pair_overlap,
_shards_get_overlap_region_wrt_saved_tensor,
)
__all__: list[str] = ["create_read_items_for_chunk_list"]
def _compare_save_plans(plan: SavePlan, other_plan: SavePlan) -> bool:
"""
Compare the two Save plans and return True if they are equal.
Args:
plan (SavePlan): First SavePlan to compare.
other_plan (SavePlan): Second SavePlan to compare.
Returns:
True if the two plans are equal, False otherwise.
"""
if plan.usable != other_plan.usable:
return False
# Both the plans should have the same number of items
if len(plan.items) != len(other_plan.items):
return False
# Both the plans should have the same write items.
for plan_item, other_plan_item in zip(plan.items, other_plan.items):
# Write item type should be same
if plan_item.type != other_plan_item.type:
return False
plan_metadata_index = plan_item.index
other_plan_metadata_index = other_plan_item.index
# Write item metadata_index should be same
if (
plan_metadata_index.fqn != other_plan_metadata_index.fqn
or plan_metadata_index.offset != other_plan_metadata_index.offset
or plan_metadata_index.index != other_plan_metadata_index.index
):
return False
# Write item tensor_data should be present in both the write items plans, if it exists in either of them.
tensor_data = plan_item.tensor_data
other_tensor_data = other_plan_item.tensor_data
if (tensor_data and not other_tensor_data) or (
not tensor_data and other_tensor_data
):
return False
if tensor_data and other_tensor_data:
# Write item tensor_data size should be same
if tensor_data.size != other_tensor_data.size:
return False
# Write item tensor_data chunk should be present in both the write items, if it exists in either of them.
chunk = tensor_data.chunk
other_chunk = other_tensor_data.chunk
if (chunk and not other_chunk) or (not chunk and other_chunk):
return False
# Write item tensor_data chunk offsets and sizes should be same
if chunk and other_chunk:
if (
chunk.offsets != other_chunk.offsets
or chunk.sizes != other_chunk.sizes
):
return False
return True
def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool:
"""
Check if any delta plan is usable, indicating the plan has changed.
Args:
delta_plans (List[SavePlan]): A list of delta plans to check.
Returns:
True if any delta plan is usable, False otherwise.
"""
return any(delta_plan and delta_plan.usable for delta_plan in delta_plans)
def _merge_delta_local_plans(
cached_plans: list[SavePlan],
delta_plans: list[SavePlan],
) -> list[SavePlan]:
"""
Merge a list of delta plans into a single plan.
Args:
cached_plans (List[SavePlan]): A list of cached plans.
delta_plans (List[SavePlan]): A list of delta plans to merge. It can contain empty plans
Returns:
A single merged plan. If a delta plan is not usable, use the cached plan. Otherwise, use the delta plan.
"""
merged_plans = []
for cached_plan, delta_plan in zip(cached_plans, delta_plans):
if delta_plan and not delta_plan.usable:
merged_plans.append(cached_plan)
else:
merged_plans.append(delta_plan)
return merged_plans
def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
return ChunkStorageMetadata(
offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()
)
def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
return ChunkStorageMetadata(
offsets=torch.Size(shard_md.shard_offsets),
sizes=torch.Size(shard_md.shard_sizes),
)
def _sharded_tensor_metadata(
sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
shard_properties = sharded_tensor.metadata().tensor_properties
properties = TensorProperties(
dtype=shard_properties.dtype,
layout=shard_properties.layout,
requires_grad=shard_properties.requires_grad,
memory_format=shard_properties.memory_format,
pin_memory=shard_properties.pin_memory,
)
return TensorWriteData(
chunk=_chunk_for_shard(shard_md),
properties=properties,
size=sharded_tensor.metadata().size,
)
def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
sizes, offsets = compute_local_shape_and_global_offset(
tensor.shape, tensor.device_mesh, tensor.placements
)
sizes, offsets = torch.Size(sizes), torch.Size(offsets)
return WriteItem(
index=MetadataIndex(fqn, offsets),
type=WriteItemType.SHARD,
tensor_data=TensorWriteData(
chunk=ChunkStorageMetadata(
offsets=offsets,
sizes=sizes,
),
properties=TensorProperties.create_from_tensor(tensor.to_local()),
size=tensor.size(),
),
)
def _create_write_item_for_shard(
fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
offsets = torch.Size(shard_md.shard_offsets)
return WriteItem(
index=MetadataIndex(fqn, offsets),
type=WriteItemType.SHARD,
tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
)
def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
offsets = torch.Size([0] * len(tensor.size()))
return WriteItem(
index=MetadataIndex(fqn, offsets),
type=WriteItemType.TENSOR,
tensor_data=TensorWriteData(
chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
properties=TensorProperties.create_from_tensor(tensor),
size=tensor.size(),
),
)
def _create_write_item_for_bytesio(fqn: str, bytes: Any):
return WriteItem(
index=MetadataIndex(fqn),
type=WriteItemType.BYTE_IO,
)
def _create_read_item_for_byteio(
dest_index, dest_offset, storage_index, storage_offset, length
):
return ReadItem(
type=LoadItemType.BYTE_IO,
dest_index=dest_index,
dest_offsets=torch.Size((dest_offset,)),
storage_index=storage_index,
storage_offsets=torch.Size((storage_offset,)),
lengths=torch.Size((length,)),
)
def _create_read_item_for_tensor(
dest_index, dest_offsets, storage_index, storage_offsets, lengths
):
return ReadItem(
type=LoadItemType.TENSOR,
dest_index=dest_index,
dest_offsets=torch.Size(dest_offsets),
storage_index=storage_index,
storage_offsets=torch.Size(storage_offsets),
lengths=torch.Size(lengths),
)
def create_read_items_for_chunk_list(
fqn: str,
checkpoint_md: TensorStorageMetadata,
local_chunks: list[ChunkStorageMetadata],
) -> list[ReadItem]:
"""
Create a list of ``ReadItem`` based on the checkpoint and local chunks.
This applies the resharding algorithm and computes the reads needed
to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.
Args:
fqn (str) : The state_dict FQN to pass to ``ReadItem``.
checkpoint_md (TensorStorageMetadata): metadata for a given tensor
from a checkpoint.
local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be
loaded.
Returns:
A list of ``ReadItem`` that will satisfy all input chunks.
"""
read_items = []
# this is a naive quadratic algo that can be optimized later
for idx, shard in enumerate(local_chunks):
for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
if not _check_shard_metadata_pair_overlap(shard, storage_md):
continue
storage_offsets = []
dest_offsets = []
lengths = []
for (
_dim,
offset_for_saved_tensor,
offset_for_current_tensor,
length,
) in _shards_get_overlap_region_wrt_saved_tensor(
saved_shard=storage_md, current_shard=shard
):
storage_offsets.append(offset_for_saved_tensor)
dest_offsets.append(offset_for_current_tensor)
lengths.append(length)
read_items.append(
_create_read_item_for_tensor(
dest_index=MetadataIndex(fqn, shard.offsets, idx),
dest_offsets=dest_offsets,
storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
storage_offsets=storage_offsets,
lengths=lengths,
)
)
return read_items
def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
requests = []
for fqn, obj in state_dict.items():
if isinstance(obj, DTensor):
requests.append(_create_write_items_for_dtensor(fqn, obj))
elif isinstance(obj, ShardedTensor):
requests.extend(
_create_write_item_for_shard(fqn, obj, shard_md)
for shard_md in obj.metadata().shards_metadata
)
elif isinstance(obj, torch.Tensor):
requests.append(_create_write_item_for_tensor(fqn, obj))
else:
requests.append(_create_write_item_for_bytesio(fqn, obj))
return SavePlan(requests)
def _create_write_items(fqn: str, object: Any) -> list[WriteItem]:
if hasattr(object, "__create_write_items__"):
# DTensor implements _Checkpointable
return object.__create_write_items__(fqn, object)
elif isinstance(object, ShardedTensor):
return [
_create_write_item_for_shard(fqn, object, shard.metadata)
for shard in object.local_shards()
]
elif isinstance(object, torch.Tensor):
return [_create_write_item_for_tensor(fqn, object)]
else:
return [_create_write_item_for_bytesio(fqn, object)]
def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
sizes, offsets = compute_local_shape_and_global_offset(
tensor.shape, tensor.device_mesh, tensor.placements
)
sizes, offsets = torch.Size(sizes), torch.Size(offsets)
return ChunkStorageMetadata(
offsets=offsets,
sizes=sizes,
)
def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]:
if hasattr(tensor, "__create_chunk_list__"):
# DTensor implements _Checkpointable
local_chunks = tensor.__create_chunk_list__() # type: ignore[attr-defined]
elif isinstance(tensor, ShardedTensor):
local_chunks = [
_chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
]
elif isinstance(tensor, torch.Tensor):
local_chunks = [_create_chunk_from_tensor(tensor)]
else:
raise ValueError(
"Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] "
f",but got {type(tensor)}"
)
return local_chunks
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]:
if not isinstance(md, BytesStorageMetadata):
try:
local_chunks = _create_chunk_list(obj)
except ValueError as ex:
raise ValueError(
f"Invalid checkpoint metadata for {fqn}, "
+ f"expected BytesStorageMetadata but found {type(md)}",
) from ex
return create_read_items_for_chunk_list(fqn, md, local_chunks)
else:
return [
_create_read_item_for_byteio(
dest_index=MetadataIndex(fqn),
dest_offset=0,
storage_index=MetadataIndex(fqn),
storage_offset=0,
length=0,
)
]
def _init_state_dict(state_dict: dict[str, Any]) -> Any:
"""
Initializes meta tensor if the meta tensor is DTensor or torch.Tensor.
"""
def dtensor_func(value: DTensor):
device = getattr(value, "device", None)
if device == torch.device("meta"):
device_type = dist.distributed_c10d._get_pg_default_device().type
device = cast(
torch.device, _get_device_module(device_type).current_device()
)
new_local_tensor = torch.empty_like(value.to_local(), device=device)
# We need to pass shape and stride explicitly, since DTensor might be
# sharded unevenly.
dtensor = DTensor.from_local(
new_local_tensor,
device_mesh=value.device_mesh,
placements=value.placements,
shape=value.size(),
stride=value.stride(),
)
return dtensor
else:
return value
def sharded_tensor_func(value: Any):
device = getattr(value, "device", None)
if device == torch.device("meta"):
raise RuntimeError(
f"Found unsupported type {type(value)} for meta device loading."
)
else:
return value
def tensor_func(value: torch.Tensor):
device = getattr(value, "device", None)
if device == torch.device("meta"):
device_type = dist.distributed_c10d._get_pg_default_device().type
device = cast(
torch.device, _get_device_module(device_type).current_device()
)
tensor = torch.empty_like(value, device=device)
return tensor
else:
return value
_iterate_state_dict(
state_dict,
dtensor_func,
sharded_tensor_func,
tensor_func,
)
def _iterate_state_dict(
iter_object: Any,
dtensor_func: Callable,
sharded_tensor_func: Callable,
tensor_func: Callable,
):
"""
Iterate through the state dict, applying the given functions to each tensor type
and update the state dict in place.
Args:
iter_object (Any): the target state_dict.
sharded_tensor_func (Callable): the function to apply to ShardedTensor
dtensor_func (Callable): the function to apply to DTensor
tensor_func (Callable): the function to apply to Tensor
# TODO: let state_dict_util._iterate_state_dict() to support in place option
so we don't need to have two versions of _iterate_state_dict.
"""
if isinstance(iter_object, DTensor):
return dtensor_func(iter_object)
elif isinstance(iter_object, ShardedTensor):
return sharded_tensor_func(iter_object)
elif isinstance(iter_object, torch.Tensor):
return tensor_func(iter_object)
elif (
isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
or iter_object is None
):
return iter_object
elif isinstance(iter_object, dict):
for key, value in iter_object.items():
iter_object[key] = _iterate_state_dict(
value, dtensor_func, sharded_tensor_func, tensor_func
)
return iter_object
elif isinstance(iter_object, (list, tuple)):
ret = [
_iterate_state_dict(v, dtensor_func, sharded_tensor_func, tensor_func)
for v in iter_object
]
if isinstance(iter_object, tuple):
ret = tuple(ret) # type: ignore[assignment]
return ret