mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes: * many path renames and path import fixes * a dedicated doc page without too much content yet (adding in the next PRs) * To preserve the BC for users still using the `torch.distributed._tensor`, I added a shim script to redirect old path calls to the new module. BC preservation is evidenced by the fact that all DTensor tests are still working without changing the public imports, so it's safe to land the changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113 Approved by: https://github.com/XilunWu ghstack dependencies: #133305, #133306
357 lines
13 KiB
Python
357 lines
13 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
|
|
import dataclasses
|
|
from typing import cast, Dict, List, Optional, Sequence, Tuple, Union
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch._utils import _get_device_module
|
|
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
|
|
from torch.distributed._shard.sharded_tensor.metadata import (
|
|
TensorProperties as ShardTensorProperties,
|
|
)
|
|
from torch.distributed._shard.sharded_tensor.shard import Shard
|
|
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
|
|
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
|
|
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
|
|
from torch.distributed.checkpoint.metadata import (
|
|
BytesStorageMetadata,
|
|
ChunkStorageMetadata,
|
|
Metadata,
|
|
MetadataIndex,
|
|
STATE_DICT_TYPE,
|
|
TensorProperties,
|
|
TensorStorageMetadata,
|
|
)
|
|
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
|
|
from torch.distributed.checkpoint.planner_helpers import (
|
|
_create_read_items,
|
|
create_read_items_for_chunk_list,
|
|
)
|
|
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
|
|
from torch.distributed.checkpoint.storage import StorageReader
|
|
from torch.distributed.checkpoint.utils import (
|
|
_element_wise_add,
|
|
_element_wise_sub,
|
|
_normalize_device_info,
|
|
)
|
|
from torch.distributed.distributed_c10d import _get_default_group
|
|
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
|
|
from torch.distributed.remote_device import _remote_device
|
|
from torch.distributed.tensor import DTensor
|
|
|
|
|
|
# Maps each model state_dict key to an ``(offset, size)`` pair describing the
# slice this rank should load: ``offset`` is ``None`` for entries that are not
# nested ShardedTensors, otherwise the local shard's offsets; ``size`` is the
# size to allocate for that entry.
STATE_DICT_2D_LAYOUT = Dict[
    str,
    Tuple[Optional[Sequence[int]], Sequence[int]],
]
|
|
|
|
|
|
# TODO: Update docstrings for optimizer.py
# Public API of this module; every other definition here is private.
__all__ = ["load_sharded_optimizer_state_dict"]
|
|
|
|
|
|
def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
|
|
if device_type == "cpu":
|
|
return "cpu"
|
|
device_module = _get_device_module(device_type)
|
|
if device_module.is_available():
|
|
return _normalize_device_info(
|
|
device_type, global_rank % device_module.device_count()
|
|
)
|
|
return "cpu"
|
|
|
|
|
|
def _create_colwise_spec(
    pg: Optional[dist.ProcessGroup] = None,
) -> ChunkShardingSpec:
    """Build a dim-0 ``ChunkShardingSpec`` with one placement per member of ``pg``.

    Each placement reads ``"rank:<group rank>/<device>"``. The device is chosen
    by ``_gen_rank_device`` from the member's *global* rank and from the
    default device type of ``pg``. When ``pg`` is ``None`` the default group is
    used, and its group ranks are the global ranks ``0..world_size-1``.
    """
    device_type = dist.distributed_c10d._get_pg_default_device(pg).type

    if pg is None:
        global_ranks = list(range(dist.get_world_size()))
    else:
        global_ranks = [dist.get_global_rank(pg, r) for r in range(pg.size())]

    # The placement's "rank:" prefix uses the rank *within* the group, while the
    # device is derived from the global rank.
    placements = [
        f"rank:{group_rank}/{_gen_rank_device(global_rank, device_type)}"
        for group_rank, global_rank in enumerate(global_ranks)
    ]
    return ChunkShardingSpec(
        dim=0,
        placements=cast(List[Union[_remote_device, str]], placements),
    )
|
|
|
|
|
|
def _is_nested_tensor(val: torch.Tensor) -> bool:
|
|
if type(val) is ShardedTensor:
|
|
if len(val.local_shards()) == 0:
|
|
return False
|
|
if type(val.local_shards()[0].tensor) is ShardedTensor:
|
|
return True
|
|
if type(val.local_shards()[0].tensor) is DTensor:
|
|
raise ValueError("Cannot handle DTensor nested insided ShardedTensor")
|
|
elif type(val) is DTensor and (
|
|
type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
|
|
):
|
|
raise ValueError("Cannot handle nested DTensor")
|
|
return False
|
|
|
|
|
|
def _alloc_tensor(
|
|
props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
|
|
) -> torch.Tensor:
|
|
if device_type == "cpu":
|
|
device = cast(torch.device, _get_device_module(device_type).current_device())
|
|
else:
|
|
device = torch.device(
|
|
device_type, _get_device_module(device_type).current_device()
|
|
)
|
|
|
|
return torch.empty(
|
|
size=size,
|
|
dtype=props.dtype,
|
|
layout=props.layout,
|
|
requires_grad=props.requires_grad,
|
|
pin_memory=props.pin_memory,
|
|
device=device,
|
|
)
|
|
|
|
|
|
def _get_state_dict_2d_layout(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]:
    """
    Compute, for every entry of a model state_dict, which slice this rank should load.

    The per-tensor TP slicing can't be inferred from checkpoint metadata, so it
    is read off the model state_dict instead. A nested ShardedTensor (an ST
    whose single local shard is itself an ST) carries the offsets and sizes of
    this rank's slice. This relies on the layout FSDP produces and is fragile;
    it might be easier for FSDP to compute this info directly.

    N.B. The state_dict *MUST* come from FSDP.sharded_state_dict.

    Args:
        state_dict: The sharded model state_dict.

    Returns:
        A ``(specs, dp_pg)`` tuple:

        - ``specs`` maps every key of ``state_dict`` to ``(offset, size)``. For
          nested entries these are the local shard's offsets and sizes;
          otherwise they are ``(None, value.size())``.
        - ``dp_pg`` is the process group of the inner ShardedTensor of the last
          nested entry seen, or ``None`` if there was none.
    """
    layout: STATE_DICT_2D_LAYOUT = {}
    data_parallel_pg: Optional[dist.ProcessGroup] = None

    for fqn, tensor in state_dict.items():
        full_size = tensor.size()
        if not _is_nested_tensor(tensor):
            layout[fqn] = (None, full_size)
            continue

        assert (
            len(tensor.local_shards()) == 1
        ), "Cannot handle ST with multiple shards"
        assert isinstance(
            tensor, ShardedTensor
        ), "Can only handle nested ShardedTensor"

        local_shard = tensor.local_shards()[0]
        layout[fqn] = (
            local_shard.metadata.shard_offsets,
            local_shard.metadata.shard_sizes,
        )
        # The inner ST is sharded over the data-parallel group.
        data_parallel_pg = local_shard.tensor._process_group  # type: ignore[attr-defined]

    return layout, data_parallel_pg
|
|
|
|
|
|
class _ReaderWithOffset(DefaultLoadPlanner):
|
|
translation: Dict[MetadataIndex, MetadataIndex]
|
|
state_dict: STATE_DICT_TYPE
|
|
metadata: Metadata
|
|
|
|
def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
|
|
super().__init__()
|
|
self.fqn_to_offset = fqn_to_offset
|
|
self.metadata = Metadata({})
|
|
self.state_dict = {}
|
|
self.translation = {}
|
|
|
|
def create_local_plan(self) -> LoadPlan:
|
|
requests = []
|
|
self.translation = {}
|
|
for fqn, obj in self.state_dict.items():
|
|
md = self.metadata.state_dict_metadata[fqn]
|
|
if not isinstance(obj, ShardedTensor):
|
|
requests += _create_read_items(fqn, md, obj)
|
|
continue
|
|
|
|
if fqn not in self.fqn_to_offset:
|
|
requests += _create_read_items(fqn, md, obj)
|
|
continue
|
|
|
|
offset = self.fqn_to_offset[fqn]
|
|
|
|
assert len(obj.local_shards()) == 1
|
|
original_shard = obj.local_shards()[0]
|
|
local_chunks = [
|
|
ChunkStorageMetadata(
|
|
offsets=torch.Size(
|
|
_element_wise_add(original_shard.metadata.shard_offsets, offset)
|
|
),
|
|
sizes=torch.Size(original_shard.metadata.shard_sizes),
|
|
)
|
|
]
|
|
|
|
reqs = create_read_items_for_chunk_list(
|
|
fqn, cast(TensorStorageMetadata, md), local_chunks
|
|
)
|
|
# TODO: The ReadItems will have a displaced MetadataIndex, fix it.
|
|
# TODO: we should change _create_sharded_read_items to have more ergonomic API
|
|
for ri in reqs:
|
|
assert ri.dest_index.offset is not None
|
|
original_offset = _element_wise_sub(ri.dest_index.offset, offset)
|
|
original_index = dataclasses.replace(
|
|
ri.dest_index, offset=torch.Size(original_offset)
|
|
)
|
|
self.translation[ri.dest_index] = original_index
|
|
|
|
requests += reqs
|
|
return LoadPlan(requests)
|
|
|
|
def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
|
|
return super().lookup_tensor(self.translation.get(index, index))
|
|
|
|
|
|
def load_sharded_optimizer_state_dict(
    model_state_dict: STATE_DICT_TYPE,
    optimizer_key: str,
    storage_reader: StorageReader,
    planner: Optional[LoadPlanner] = None,
) -> STATE_DICT_TYPE:
    """
    Load a state_dict in conjunction with FSDP sharded optimizer state.

    This is the current recommended way to checkpoint FSDP.

    Args:
        model_state_dict: The model state_dict obtained under
            ``StateDictType.SHARDED_STATE_DICT``. Nested ShardedTensors in it
            determine which slice of each optimizer state tensor this rank
            loads (the 2D / FSDP+TP case).
        optimizer_key: The top-level key under which the optimizer state was
            saved. Only checkpoint entries whose planner path starts with this
            key are loaded.
        storage_reader: Reader used to fetch the checkpoint metadata and data.
        planner: Load planner forwarded to ``load_state_dict``. It is only used
            when ``model_state_dict`` contains no nested ShardedTensor;
            otherwise an internal offset-aware planner is used instead.

    Returns:
        The loaded optimizer state, unflattened using the checkpoint's planner
        data, so it is found under ``optimizer_key``.

    >>> # xdoctest: +SKIP
    >>> import torch.distributed.checkpoint as dist_cp
    >>> # Save
    >>> model: torch.nn.Module
    >>> optim_params = model.parameters()
    >>> optim = torch.optim.SGD(optim_params, lr=0.01)
    >>> # Save
    >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    >>>     state_dict = {
    >>>         "optimizer": FSDP.optim_state_dict(model, optim),
    >>>         "model": model.state_dict()
    >>>     }
    >>>     dist_cp.save_state_dict(
    >>>         state_dict=state_dict,
    >>>         storage_writer=dist_cp.FileSystemWriter("checkpoint"),
    >>>         planner=dist_cp.DefaultSavePlanner(),
    >>>     )
    >>>
    >>> # Load
    >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    >>>     model_state_dict = model.state_dict()
    >>>     checkpoint = {
    >>>         "model": model_state_dict
    >>>     }
    >>>     dist_cp.load_state_dict(
    >>>         state_dict=checkpoint,
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>         planner=dist_cp.DefaultLoadPlanner(),
    >>>     )
    >>>     model.load_state_dict(checkpoint["model"])
    >>>
    >>>     optim_state = dist_cp.load_sharded_optimizer_state_dict(
    >>>         model_state_dict,
    >>>         optimizer_key="optimizer",
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>     )
    >>>
    >>>     flattened_osd = FSDP.optim_state_dict_to_load(
    >>>         model, optim, optim_state["optimizer"]
    >>>     )
    >>>
    >>>     optim.load_state_dict(flattened_osd)
    """
    metadata = storage_reader.read_metadata()

    layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
    dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
    device_module = _get_device_module(dp_pg_device_type)

    # Only the 2D path (dp_pg is not None) builds sharded tensors from a spec.
    # The 1D path uses _create_chunk_sharded_tensor and needs no spec.
    sharding_spec: Optional[ChunkShardingSpec] = (
        _create_colwise_spec(dp_pg) if dp_pg is not None else None
    )

    # Create a state_dict for optimizer state
    state_dict: STATE_DICT_TYPE = {}

    # Offset of this rank's TP slice per key; used to shift the read requests.
    fqn_to_offset: Dict[str, Sequence[int]] = {}
    for key, value in metadata.state_dict_metadata.items():
        key_path = metadata.planner_data[key]
        if key_path[0] != optimizer_key:
            continue

        if isinstance(value, BytesStorageMetadata):
            # Placeholder entry for non-tensor (bytes) state.
            state_dict[key] = "<bytes_io>"
            continue

        # value: TensorStorageMetadata
        if value.size.numel() == 1:
            # Single-element tensors are allocated whole on every rank.
            state_dict[key] = _alloc_tensor(
                value.properties, value.size, dp_pg_device_type
            )
        elif dp_pg is None:
            # 1D: chunk-shard the full tensor over the default group.
            state_dict[key] = _create_chunk_sharded_tensor(
                _alloc_tensor(value.properties, value.size, dp_pg_device_type),
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
                num_devices_per_node=device_module.device_count(),
                pg=_get_default_group(),
            )
        else:
            # 2D: shard this rank's TP slice over the data-parallel group.
            assert sharding_spec is not None  # dp_pg is not None on this branch
            # NOTE(review): key_path[2] is assumed to be the model parameter FQN
            # (path like [optimizer_key, "state", fqn, ...]) - confirm.
            spec_key = key_path[2]
            alloc_size = layout_specs.get(spec_key, (None, value.size))[1]

            properties = ShardTensorProperties(
                dtype=value.properties.dtype,
                layout=value.properties.layout,
                requires_grad=value.properties.requires_grad,
                memory_format=value.properties.memory_format,
                pin_memory=value.properties.pin_memory,
            )

            st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties)
            local_shards = []
            current_rank = dist.get_rank(dp_pg)
            for shard_md in st_md.shards_metadata:
                # Only allocate the shards placed on this rank.
                if cast(_remote_device, shard_md.placement).rank() != current_rank:
                    continue
                local_shards.append(
                    Shard(
                        tensor=_alloc_tensor(
                            value.properties, shard_md.shard_sizes, dp_pg_device_type
                        ),
                        metadata=shard_md,
                    )
                )

            st = ShardedTensor._init_from_local_shards_and_global_metadata(
                local_shards, st_md, process_group=dp_pg
            )

            if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
                fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])

            state_dict[key] = st

    # Whether we unflatten before or after doesn't matter
    load_state_dict(
        state_dict=state_dict,
        storage_reader=storage_reader,
        # FIXME the type of planner is wrong in load_state_dict
        planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner,
    )

    state_dict = unflatten_state_dict(state_dict, metadata.planner_data)

    return state_dict
|