pytorch/torch/distributed/checkpoint/_traverse.py
Wanchao Liang cfc227ad43 [reland][dtensor] move DTensor to public namespace (#134203)
reland of https://github.com/pytorch/pytorch/pull/133113

I had to create a new PR because the previously reverted PR could not be rebased or imported successfully :(

----

Moving DTensor to the public namespace, to formally add a documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve BC for users still importing from torch.distributed._tensor, I added a shim module that redirects old import paths to the new module

The BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it is safe to land the changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
2024-09-08 17:08:40 +00:00


# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import (
    Callable,
    cast,
    Collection,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import torch
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.tensor import DTensor


PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
T = TypeVar("T")

STATE_DICT_ITEM = object
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]

__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]


def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
    return isinstance(value, torch.Tensor)


# TODO: update docstring for traverse.py
def traverse_state_dict(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` for each value recursively in ``state_dict``.

    Mappings are traversed and ``visitor`` is applied to the leaf elements.
    ``visitor`` is applied to the individual elements of a list or tuple only
    if the container holds tensors or mappings; otherwise the container itself
    is treated as a leaf.
    """

    def _is_terminal(value: STATE_DICT_ITEM) -> bool:
        values: Collection[STATE_DICT_ITEM]
        if isinstance(value, Mapping):
            return False
        elif isinstance(value, list):
            values = value
        else:
            return True

        for entry in values:
            if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
                return False
            if keep_traversing is not None and keep_traversing(entry):
                return False

        return True

    def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if isinstance(value, Mapping):
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif _is_terminal(value):
            visitor(path, value)
        elif isinstance(value, (list, tuple)):
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)
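
# Illustrative usage (a sketch, not part of the original module): collect the
# path of every leaf in a toy state dict. ``sd`` and ``leaves`` are hypothetical.
#
#   sd = {"model": {"weight": torch.zeros(2, 2)}, "step": 7}
#   leaves = []
#   traverse_state_dict(sd, lambda path, value: leaves.append(path))
#   # leaves == [("model", "weight"), ("step",)]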


def traverse_state_dict_v_2_3(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` for each value recursively in ``state_dict``.

    Traversal is short-circuited when it finds a collection for which
    ``keep_traversing`` evaluates to false for all elements.
    By default, all collections with at least one ``torch.Tensor`` element are traversed.
    Visitor takes a path argument that is a tuple of the keys used to reach it.
    """

    # a value is terminal if it has no container values inside it
    def _is_terminal(value: STATE_DICT_ITEM) -> bool:
        values: Collection[STATE_DICT_ITEM]
        if isinstance(value, Mapping):
            values = value.values()
        elif isinstance(value, list):
            values = value
        else:
            return True

        for entry in values:
            if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
                return False
            if keep_traversing is not None and keep_traversing(entry):
                return False

        return True

    def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if _is_terminal(value):
            visitor(path, value)
        elif isinstance(value, Mapping):
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif isinstance(value, list):
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)
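
# Illustrative usage (a sketch, not part of the original module): this variant
# short-circuits tensor-free collections, visiting them as single leaves.
#
#   sd = {"lrs": [0.1, 0.01], "params": [torch.ones(1)]}
#   visited = []
#   traverse_state_dict_v_2_3(sd, lambda path, value: visited.append(path))
#   # visited == [("lrs",), ("params", 0)]: the tensor-free "lrs" list is one
#   # leaf, while the tensor inside "params" is reached individually.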


def set_element(
    root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
) -> None:
    """Set ``value`` in ``root_dict`` along the ``path`` object path."""
    cur_container = cast(CONTAINER_TYPE, root_dict)

    def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None:
        while len(lst) <= idx:
            lst.append(None)

    for i in range(1, len(path)):
        prev_key = path[i - 1]
        key = path[i]
        def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])

        if isinstance(cur_container, Mapping):
            cur_container = cast(
                CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
            )
        else:
            extend_list(cur_container, prev_key)
            if cur_container[prev_key] is None:
                cur_container[prev_key] = def_val
            cur_container = cur_container[prev_key]

    key = path[-1]
    if type(key) == int:
        extend_list(cast(List[STATE_DICT_ITEM], cur_container), key)

    cur_container[key] = value
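
# Illustrative usage (a sketch, not part of the original module): string path
# components create dicts, integer components create lists padded with ``None``.
#
#   sd: STATE_DICT_TYPE = {}
#   set_element(sd, ("optim", "param_groups", 1), "sgd")
#   # sd == {"optim": {"param_groups": [None, "sgd"]}}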


def get_element(
    root_dict: STATE_DICT_TYPE,
    path: OBJ_PATH,
    default_value: Optional[T] = None,
) -> Optional[T]:
    """Retrieve the value at ``path`` from ``root_dict``, returning ``default_value`` if not found."""
    cur_value = cast(CONTAINER_TYPE, root_dict)
    for part in path:
        if type(part) is int:
            # valid list indices are 0..len-1, so an index past the end is "not found"
            if not isinstance(cur_value, list) or len(cur_value) <= part:
                return default_value
        elif not isinstance(cur_value, Mapping) or part not in cur_value:
            return default_value

        cur_value = cast(CONTAINER_TYPE, cur_value[part])
    return cast(Optional[T], cur_value)
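
# Illustrative usage (a sketch, not part of the original module): out-of-range
# indices and missing keys fall back to ``default_value``.
#
#   sd = {"optim": {"param_groups": [None, "sgd"]}}
#   get_element(sd, ("optim", "param_groups", 1))  # -> "sgd"
#   get_element(sd, ("optim", "param_groups", 5), default_value="n/a")  # -> "n/a"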


def _print_nested(
    value: STATE_DICT_ITEM,
    prefix: str = "",
    print_fun: Callable[[str], None] = print,
) -> None:
    if type(value) is ShardedTensor:
        print_fun(f"{prefix} ShardedTensor size: {value.size()}")
        for shard in value.local_shards():
            _print_nested(
                shard.tensor,
                f"{shard.metadata.shard_offsets} ",
                print_fun=print_fun,
            )
    elif type(value) is DTensor:
        print_fun(f"{prefix} DistributedTensor size: {value.size()}")
        # TODO: add local offset for _local_tensor in print_nested.
        _print_nested(
            value._local_tensor,
            print_fun=print_fun,
        )
    elif isinstance(value, torch.Tensor):
        print_fun(f"{prefix} Tensor size: {value.size()}")
    else:
        print_fun(f"{prefix} Type: {type(value)}")


def print_tensor(
    path: OBJ_PATH,
    value: STATE_DICT_ITEM,
    print_fun: Callable[[str], None] = print,
) -> None:
    """
    Use this callback with ``traverse_state_dict`` to print its content.

    By default the content is printed using the builtin ``print``, but this can
    be changed by passing a different ``print_fun`` callable.
    """
    _print_nested(value, prefix=str(path), print_fun=print_fun)
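

# Minimal runnable sketch (not part of the original module): combine
# ``traverse_state_dict`` with ``print_tensor`` to dump a toy state dict.
# ``_demo_state_dict`` is a hypothetical name used only for this demo.
if __name__ == "__main__":
    _demo_state_dict = {"model": {"weight": torch.zeros(2, 2)}, "step": 7}
    traverse_state_dict(_demo_state_dict, print_tensor)
    # expected output:
    #   ('model', 'weight') Tensor size: torch.Size([2, 2])
    #   ('step',) Type: <class 'int'>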