pytorch/torch/distributed/checkpoint/_traverse.py
Wanchao Liang 2ee6b97464 [dtensor] move DTensor to public namespace (#133113)
Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
  PRs)
* To preserve the BC for users still using the `torch.distributed._tensor`,
  I added a shim script to redirect old path calls to the new module

Backward compatibility is evidenced by the fact that all DTensor tests still
pass without changing the public imports, so it is safe to land the
changes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
2024-08-17 05:09:52 +00:00

169 lines
5.3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import (
Callable,
cast,
Collection,
List,
Mapping,
MutableMapping,
Optional,
Tuple,
TypeVar,
Union,
)
import torch
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.tensor import DTensor
# A single step in an object path: a dict key (str) or a sequence index (int).
PATH_ITEM = Union[str, int]
# Full path from the state_dict root down to a value.
OBJ_PATH = Tuple[PATH_ITEM, ...]
T = TypeVar("T")
# Values held in a state_dict can be of any type.
STATE_DICT_ITEM = object
# Mutable container addressed by PATH_ITEM keys (dict-like view used while
# building nested structures in set_element/get_element).
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]
__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
    """Default ``keep_traversing`` predicate: descend only into tensors."""
    is_tensor = isinstance(value, torch.Tensor)
    return is_tensor
# TODO: update docstring for traverse.py
def traverse_state_dict(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` for each value recursively in ``state_dict``.

    Mapping, list, and tuple will be flattened and other value types are treated
    as the terminal values and will invoke ``visitor``.

    Mapping is treated as non-terminal and is always flattened.
    List and tuple, on the other hand, are only flattened when they contain
    other mapping containers or values for which ``keep_traversing`` is true
    (tensors, by default).
    """

    # A value is terminal if it has no container values (that should be
    # descended into) inside it.
    def _is_terminal(value: STATE_DICT_ITEM) -> bool:
        values: Collection[STATE_DICT_ITEM]
        if isinstance(value, Mapping):
            return False
        elif isinstance(value, (list, tuple)):
            # Fix: tuples were previously always treated as terminal, which
            # contradicted the docstring and made the tuple branch of
            # _traverse_obj unreachable. Inspect tuple entries like lists.
            values = value
        else:
            return True

        for entry in values:
            if isinstance(entry, (Mapping, list, tuple)) and not _is_terminal(entry):
                return False
            if keep_traversing is not None and keep_traversing(entry):
                return False
        return True

    def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if isinstance(value, Mapping):
            # Dict keys are stringified so paths are homogeneous for mappings.
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif _is_terminal(value):
            visitor(path, value)
        elif isinstance(value, (list, tuple)):
            # Sequence elements are addressed by their integer index.
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)
def set_element(
    root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
) -> None:
    """Set ``value`` in ``root_dict`` along the ``path`` object path.

    Intermediate containers are created on demand: a ``str`` path item implies
    a dict level, an ``int`` path item implies a list level (grown with
    ``None`` placeholders as needed).
    """
    cur_container = cast(CONTAINER_TYPE, root_dict)

    def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None:
        # Grow ``lst`` with None placeholders so index ``idx`` is assignable.
        while len(lst) <= idx:
            lst.append(None)

    for i in range(1, len(path)):
        prev_key = path[i - 1]
        key = path[i]
        # The type of the *next* key decides what container to create at the
        # current level: str -> dict, int -> list.
        # (isinstance instead of type(...) == ... per PEP 8; PATH_ITEM is
        # only ever str or int, so behavior is unchanged.)
        def_val = cast(STATE_DICT_ITEM, {} if isinstance(key, str) else [])

        if isinstance(cur_container, Mapping):
            cur_container = cast(
                CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
            )
        else:
            extend_list(cur_container, prev_key)
            if cur_container[prev_key] is None:
                cur_container[prev_key] = def_val
            cur_container = cur_container[prev_key]

    key = path[-1]
    if isinstance(key, int):
        extend_list(cast(List[STATE_DICT_ITEM], cur_container), key)

    cur_container[key] = value
def get_element(
    root_dict: STATE_DICT_TYPE,
    path: OBJ_PATH,
    default_value: Optional[T] = None,
) -> Optional[T]:
    """Retrieve the value at ``path`` from ``root_dict``, returning ``default_value`` if not found."""
    cur_value = cast(CONTAINER_TYPE, root_dict)
    for part in path:
        if type(part) is int:
            # Fix: the bound must be inclusive — with ``len(cur_value) < part``
            # an index equal to len(cur_value) slipped through and raised
            # IndexError instead of returning ``default_value``.
            if not isinstance(cur_value, list) or len(cur_value) <= part:
                return default_value
        elif not isinstance(cur_value, Mapping) or part not in cur_value:
            return default_value

        cur_value = cast(CONTAINER_TYPE, cur_value[part])
    return cast(Optional[T], cur_value)
def _print_nested(
    value: STATE_DICT_ITEM,
    prefix: str = "",
    print_fun: Callable[[str], None] = print,
) -> None:
    """Recursively print a human-readable description of ``value`` via ``print_fun``.

    Sharded and distributed tensors are expanded into their local pieces;
    plain tensors print their size; anything else prints its type.
    """
    # isinstance instead of the original exact-type (``type(...) is``) checks:
    # subclasses are now handled by the same branch. DTensor is tested before
    # the generic torch.Tensor branch, so dispatch order is preserved.
    if isinstance(value, ShardedTensor):
        print_fun(f"{prefix} ShardedTensor size: {value.size()}")
        for shard in value.local_shards():
            # NOTE(review): the outer ``prefix`` is intentionally replaced by
            # the shard offsets here, matching the original formatting.
            _print_nested(
                shard.tensor,
                f"{shard.metadata.shard_offsets} ",
                print_fun=print_fun,
            )
    elif isinstance(value, DTensor):
        print_fun(f"{prefix} DistributedTensor size: {value.size()}")
        # TODO: add local offset for _local_tensor in print_nested.
        _print_nested(
            value._local_tensor,
            print_fun=print_fun,
        )
    elif isinstance(value, torch.Tensor):
        print_fun(f"{prefix} Tensor size: {value.size()}")
    else:
        print_fun(f"{prefix} Type: {type(value)}")
def print_tensor(
    path: OBJ_PATH,
    value: STATE_DICT_ITEM,
    print_fun: Callable[[str], None] = print,
) -> None:
    """
    Callback for ``traverse_state_dict`` that prints the visited content.

    By default the content is printed using the builtin ``print``, but this
    can be changed by passing a different ``print_fun`` callable.
    """
    _print_nested(value, prefix=str(path), print_fun=print_fun)