mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Reland of https://github.com/pytorch/pytorch/pull/133113. I have to create a new PR because the previous reverted PR could be neither rebased nor imported successfully :( ---- Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes: * many path renames and path import fixes * a dedicated doc page without too much content yet (adding in the next PRs) * to preserve BC for users still using torch.distributed._tensor, I added a shim script to redirect old path calls to the new module. The BC preservation is evidenced by the fact that all DTensor tests still work without changing the public imports, so it's safe to land the changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203 Approved by: https://github.com/tianyu-l
209 lines
6.7 KiB
Python
209 lines
6.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
from typing import (
|
|
Callable,
|
|
cast,
|
|
Collection,
|
|
List,
|
|
Mapping,
|
|
MutableMapping,
|
|
Optional,
|
|
Tuple,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
import torch
|
|
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
|
|
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
|
|
from torch.distributed.tensor import DTensor
|
|
|
|
|
|
# One step of an object path: either a ``str`` or an ``int``.
# The traversal helpers below use ``str`` steps for mapping keys and ``int``
# steps for list positions.
PATH_ITEM = Union[str, int]
# A full path: a tuple of steps, one per nesting level.
OBJ_PATH = Tuple[PATH_ITEM, ...]
# Generic type variable used by ``get_element`` for its default/return value.
T = TypeVar("T")

# Values held in a state dict are not constrained beyond ``object`` here.
STATE_DICT_ITEM = object
# Mutable mapping addressed by a single path step; ``set_element`` and
# ``get_element`` cast the containers they walk through to this type.
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]

# Names exported by a star-import of this module.
__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
|
|
|
|
|
|
def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
    """Return ``True`` only when ``value`` is a ``torch.Tensor`` instance.

    This is the default ``keep_traversing`` predicate of the traversal
    functions in this module.
    """
    is_tensor = isinstance(value, torch.Tensor)
    return is_tensor
|
|
|
|
|
|
# TODO: update docstring for traverse.py
|
|
def traverse_state_dict(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Walk ``state_dict`` recursively and call ``visitor(path, value)`` on its leaves.

    ``path`` is a tuple of the steps used to reach ``value``: mapping keys are
    converted with ``str`` and list positions are kept as ``int``.

    Mappings are always descended into. A list is descended into only when it
    holds a mapping, a nested list that itself must be descended into, or an
    element for which ``keep_traversing`` returns a true value; otherwise the
    list is handed to ``visitor`` as a single leaf. Every other value
    (including tuples) is treated as a leaf and passed to ``visitor`` whole.

    Args:
        state_dict: the mapping to walk.
        visitor: callback invoked with ``(path, value)`` for every leaf.
        keep_traversing: predicate applied to list elements; a true result
            forces the enclosing list to be walked element by element.
    """

    def _forces_descent(item: STATE_DICT_ITEM) -> bool:
        # A nested container that is itself non-terminal, or an element the
        # caller wants visited individually, makes the enclosing list non-terminal.
        if isinstance(item, (Mapping, list)) and not _is_leaf(item):
            return True
        return bool(keep_traversing is not None and keep_traversing(item))

    def _is_leaf(obj: STATE_DICT_ITEM) -> bool:
        if isinstance(obj, Mapping):
            return False
        if not isinstance(obj, list):
            # Anything that is neither a mapping nor a list is handed over as-is.
            return True
        return not any(_forces_descent(item) for item in obj)

    def _walk(path: OBJ_PATH, node: STATE_DICT_ITEM) -> None:
        if isinstance(node, Mapping):
            for name, child in node.items():
                _walk((*path, str(name)), child)
            return
        if _is_leaf(node):
            visitor(path, node)
            return
        if isinstance(node, (list, tuple)):
            for index, child in enumerate(node):
                _walk((*path, index), child)

    for top_key, top_value in state_dict.items():
        _walk((str(top_key),), top_value)
|
|
|
|
|
|
def traverse_state_dict_v_2_3(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Walk ``state_dict`` and call ``visitor(path, value)``, stopping early at terminal containers.

    Traversal is short-circuited when it finds a collection (mapping or list)
    that contains no non-terminal nested mapping/list and for which
    ``keep_traversing`` evaluates to false for all elements; such a collection
    is passed to ``visitor`` as a whole.
    By default, all collections with at least one ``torch.Tensor`` element are traversed.
    ``visitor`` takes a path argument that is a tuple of the keys used to reach
    the value: mapping keys are converted with ``str``, list positions stay ``int``.
    """

    # A value is terminal when nothing inside it needs to be visited separately.
    def _is_leaf(obj: STATE_DICT_ITEM) -> bool:
        children: Collection[STATE_DICT_ITEM]
        if isinstance(obj, Mapping):
            children = obj.values()
        elif isinstance(obj, list):
            children = obj
        else:
            return True

        for child in children:
            nested_needs_walk = isinstance(child, (Mapping, list)) and not _is_leaf(
                child
            )
            if nested_needs_walk:
                return False
            if keep_traversing is not None and keep_traversing(child):
                return False
        return True

    def _walk(path: OBJ_PATH, node: STATE_DICT_ITEM) -> None:
        if _is_leaf(node):
            visitor(path, node)
            return
        if isinstance(node, Mapping):
            for name, child in node.items():
                _walk((*path, str(name)), child)
            return
        if isinstance(node, list):
            for index, child in enumerate(node):
                _walk((*path, index), child)

    for top_key, top_value in state_dict.items():
        _walk((str(top_key),), top_value)
|
|
|
|
|
|
def set_element(
    root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
) -> None:
    """Store ``value`` in ``root_dict`` at the location described by ``path``.

    Missing intermediate containers are created on the way: a ``dict`` when the
    following path step is a ``str``, a ``list`` otherwise. Lists are padded
    with ``None`` so that the requested index exists.

    Args:
        root_dict: the mapping to modify in place.
        path: non-empty tuple of steps leading to the target slot.
        value: the object to store.
    """

    def _pad_to(seq: List[STATE_DICT_ITEM], index: int) -> None:
        # Append ``None`` placeholders until ``seq[index]`` is a valid slot.
        shortfall = index + 1 - len(seq)
        if shortfall > 0:
            seq.extend([None] * shortfall)

    node = cast(CONTAINER_TYPE, root_dict)

    # Walk every step but the last, pairing each step with the one after it:
    # the type of the *next* step decides which container to create here.
    for step, next_step in zip(path, path[1:]):
        placeholder = cast(STATE_DICT_ITEM, {} if type(next_step) is str else [])

        if isinstance(node, Mapping):
            node = cast(CONTAINER_TYPE, node.setdefault(step, placeholder))
            continue

        _pad_to(node, step)
        if node[step] is None:
            # Slot exists only as padding; replace it with a real container.
            node[step] = placeholder
        node = node[step]

    last_step = path[-1]
    if type(last_step) is int:
        _pad_to(cast(List[STATE_DICT_ITEM], node), last_step)

    node[last_step] = value
|
|
|
|
|
|
def get_element(
    root_dict: STATE_DICT_TYPE,
    path: OBJ_PATH,
    default_value: Optional[T] = None,
) -> Optional[T]:
    """Retrieve the value at ``path`` from ``root_dict``, returning ``default_value`` if not found.

    Args:
        root_dict: the mapping to look into.
        path: tuple of steps; ``int`` steps index into lists, all other steps
            are looked up as mapping keys. An empty path yields ``root_dict``.
        default_value: returned when any step cannot be resolved, i.e. the
            current container has the wrong type for the step, the key is
            absent, or the list index is out of range.

    Returns:
        The object found at ``path``, or ``default_value``.
    """
    cur_value = cast(CONTAINER_TYPE, root_dict)
    for part in path:
        if type(part) is int:
            if not isinstance(cur_value, list):
                return default_value
            # Valid positions are -len .. len-1. An index equal to the length
            # is already out of range and must yield the default, not raise.
            if not -len(cur_value) <= part < len(cur_value):
                return default_value
        elif not isinstance(cur_value, Mapping) or part not in cur_value:
            return default_value

        cur_value = cast(CONTAINER_TYPE, cur_value[part])
    return cast(Optional[T], cur_value)
|
|
|
|
|
|
def _print_nested(
|
|
value: STATE_DICT_ITEM,
|
|
prefix: str = "",
|
|
print_fun: Callable[[str], None] = print,
|
|
) -> None:
|
|
if type(value) is ShardedTensor:
|
|
print_fun(f"{prefix} ShardedTensor size: {value.size()}")
|
|
for shard in value.local_shards():
|
|
_print_nested(
|
|
shard.tensor,
|
|
f"{shard.metadata.shard_offsets} ",
|
|
print_fun=print_fun,
|
|
)
|
|
elif type(value) is (DTensor):
|
|
print_fun(f"{prefix} DistributedTensor size: {value.size()}")
|
|
# TODO: add local offset for _local_tensor in print_nested.
|
|
_print_nested(
|
|
value._local_tensor,
|
|
print_fun=print_fun,
|
|
)
|
|
elif isinstance(value, torch.Tensor):
|
|
print_fun(f"{prefix} Tensor size: {value.size()}")
|
|
else:
|
|
print_fun(f"{prefix} Type: {type(value)}")
|
|
|
|
|
|
def print_tensor(
    path: OBJ_PATH,
    value: STATE_DICT_ITEM,
    print_fun: Callable[[str], None] = print,
) -> None:
    """
    Use this callback with traverse_state_dict to print its content.

    Each line is prefixed with ``str(path)``. By default the content is printed
    using the builtin ``print``, but this can be changed by passing a different
    ``print_fun`` callable.
    """
    label = str(path)
    _print_nested(value, prefix=label, print_fun=print_fun)
|