mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes: * many path renames and path import fixes * a dedicated doc page without too much content yet (adding in the next PRs) * To preserve BC for users still using `torch.distributed._tensor`, I added a shim script to redirect old path calls to the new module. The BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it is safe to land the changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113 Approved by: https://github.com/XilunWu ghstack dependencies: #133305, #133306
169 lines
5.3 KiB
Python
169 lines
5.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
from typing import (
|
|
Callable,
|
|
cast,
|
|
Collection,
|
|
List,
|
|
Mapping,
|
|
MutableMapping,
|
|
Optional,
|
|
Tuple,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
import torch
|
|
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
|
|
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
|
|
from torch.distributed.tensor import DTensor
|
|
|
|
|
|
# A single step in an object path: a Mapping key (str) or a sequence index (int).
PATH_ITEM = Union[str, int]
# A full path to a value nested inside a state_dict, e.g. ("optimizer", "param_groups", 0).
OBJ_PATH = Tuple[PATH_ITEM, ...]
T = TypeVar("T")

# Values stored in a state_dict may be of any type; alias kept for readability.
STATE_DICT_ITEM = object
# Dict-like container type used while walking or creating nested structure.
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]

__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
|
|
|
|
|
|
def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
|
|
return isinstance(value, torch.Tensor)
|
|
|
|
|
|
def traverse_state_dict(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` once for every value reached while recursively walking
    ``state_dict``.

    Mappings are always flattened (each key contributes a str path element).
    A list is flattened only when it contains a nested mapping/list that is
    itself non-terminal, or an entry for which ``keep_traversing`` returns
    True; otherwise the whole list is handed to ``visitor`` as one value.

    NOTE(review): tuples are never inspected by the terminal check below
    (only lists are), so a tuple is always passed to ``visitor`` whole even
    when it contains tensors — confirm this asymmetry with the intended
    contract.
    """

    # A value is "terminal" when nothing inside it needs further expansion.
    def _is_terminal(obj: STATE_DICT_ITEM) -> bool:
        if isinstance(obj, Mapping):
            # Mappings are always expanded further.
            return False
        if not isinstance(obj, list):
            # Scalars, tensors, tuples, etc. are leaves.
            return True
        for item in obj:
            if isinstance(item, (Mapping, list)) and not _is_terminal(item):
                return False
            if keep_traversing is not None and keep_traversing(item):
                return False
        return True

    def _walk(prefix: OBJ_PATH, obj: STATE_DICT_ITEM) -> None:
        if isinstance(obj, Mapping):
            # Mapping keys are stringified so the path stays hashable/stable.
            for key, child in obj.items():
                _walk(prefix + (str(key),), child)
        elif _is_terminal(obj):
            visitor(prefix, obj)
        elif isinstance(obj, (list, tuple)):
            for idx, child in enumerate(obj):
                _walk(prefix + (idx,), child)

    for key, child in state_dict.items():
        _walk((str(key),), child)
|
|
|
|
|
|
def set_element(
    root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
) -> None:
    """Set ``value`` in ``root_dict`` along the ``path`` object path.

    Intermediate containers are created on demand while walking the path:
    a str element implies a dict level, an int element implies a list level
    (lists are padded with None up to the required index).
    """
    cur_container = cast(CONTAINER_TYPE, root_dict)

    def _pad_to(seq: List[STATE_DICT_ITEM], index: int) -> None:
        # Grow the list with None placeholders until ``index`` is addressable.
        while len(seq) <= index:
            seq.append(None)

    for i in range(1, len(path)):
        prev_key, key = path[i - 1], path[i]
        # The type of the *next* path element decides which container kind
        # must exist at the current position.
        default = cast(STATE_DICT_ITEM, {} if type(key) is str else [])

        if isinstance(cur_container, Mapping):
            cur_container = cast(
                CONTAINER_TYPE, cur_container.setdefault(prev_key, default)
            )
        else:
            _pad_to(cur_container, prev_key)
            if cur_container[prev_key] is None:
                cur_container[prev_key] = default
            cur_container = cur_container[prev_key]

    last = path[-1]
    if type(last) is int:
        _pad_to(cast(List[STATE_DICT_ITEM], cur_container), last)

    cur_container[last] = value
|
|
|
|
|
|
def get_element(
    root_dict: STATE_DICT_TYPE,
    path: OBJ_PATH,
    default_value: Optional[T] = None,
) -> Optional[T]:
    """Retrieve the value at ``path`` from ``root_dict``, returning ``default_value`` if not found.

    Each path element is either a Mapping key (str) or a list index (int);
    traversal stops and returns ``default_value`` as soon as an element
    fails to resolve against the current container.
    """
    # String-form casts keep the static types without evaluating the module
    # aliases at runtime (typing.cast never inspects its first argument).
    cur_value = cast("CONTAINER_TYPE", root_dict)
    for part in path:
        if type(part) is int:
            # Bug fix: a valid index requires part < len(cur_value). The old
            # `<` comparison let part == len(cur_value) slip through and
            # raise IndexError on the subscript below instead of returning
            # default_value.
            if not isinstance(cur_value, list) or len(cur_value) <= part:
                return default_value
        elif not isinstance(cur_value, Mapping) or part not in cur_value:
            return default_value

        cur_value = cast("CONTAINER_TYPE", cur_value[part])
    return cast("Optional[T]", cur_value)
|
|
|
|
|
|
def _print_nested(
|
|
value: STATE_DICT_ITEM,
|
|
prefix: str = "",
|
|
print_fun: Callable[[str], None] = print,
|
|
) -> None:
|
|
if type(value) is ShardedTensor:
|
|
print_fun(f"{prefix} ShardedTensor size: {value.size()}")
|
|
for shard in value.local_shards():
|
|
_print_nested(
|
|
shard.tensor,
|
|
f"{shard.metadata.shard_offsets} ",
|
|
print_fun=print_fun,
|
|
)
|
|
elif type(value) is (DTensor):
|
|
print_fun(f"{prefix} DistributedTensor size: {value.size()}")
|
|
# TODO: add local offset for _local_tensor in print_nested.
|
|
_print_nested(
|
|
value._local_tensor,
|
|
print_fun=print_fun,
|
|
)
|
|
elif isinstance(value, torch.Tensor):
|
|
print_fun(f"{prefix} Tensor size: {value.size()}")
|
|
else:
|
|
print_fun(f"{prefix} Type: {type(value)}")
|
|
|
|
|
|
def print_tensor(
    path: OBJ_PATH,
    value: STATE_DICT_ITEM,
    print_fun: Callable[[str], None] = print,
) -> None:
    """
    Use this callback with traverse_state_dict to print its content.

    By default the content is printed using the builtin ``print``, but this
    can be changed by passing a different ``print_fun`` callable.
    """
    # Delegate to the recursive printer, using the stringified path as prefix.
    _print_nested(value, prefix=str(path), print_fun=print_fun)
|