pytorch/torch/distributed/fsdp/utils.py
Rohan Varma 782ee6c7e7 [FSDP][Reland] Implement local_state_dict and load_local_state_dict
1. Implement the framework that lets users choose among `state_dict`, `local_state_dict`, and `sharded_state_dict`.
2. Implement ShardedTensor compatible local_state_dict() and load_local_state_dict().
ghstack-source-id: 149625958

Differential Revision: [D34383925](https://our.internmc.facebook.com/intern/diff/D34383925/)

[ghstack-poisoned]
2022-02-23 07:57:34 -08:00

51 lines
1.6 KiB
Python

from typing import Dict, List, Tuple, Union, Any, Callable, Set, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
"""Useful functions to deal with tensor types with other python container types."""
def _apply_to_tensors(
    fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]
) -> Any:
    """
    Recursively apply ``fn`` to every tensor found in ``container``.

    Args:
        fn: Callable invoked with each tensor; its return value takes the
            tensor's place in the result.
        container: A tensor, or an arbitrarily nested combination of dict,
            list, tuple (including namedtuple), and set.

    Returns:
        A new structure mirroring ``container``. Dicts (and dict subclasses)
        come back as a plain ``dict`` with the same keys; lists, tuples,
        namedtuples, and sets are rebuilt with their own type; any other
        object is returned unchanged.
    """

    def apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
        if torch.is_tensor(x):
            return fn(x)
        if isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable, so the mapped elements must be unpacked.
            return type(x)(*(apply(el) for el in x))
        if isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        # Anything that is neither a tensor nor a known container passes through.
        return x

    return apply(container)
def _replace_by_prefix(
    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
    old_prefix: str,
    new_prefix: str,
) -> None:
    """
    Rename, in place, every key of ``state_dict`` that begins with ``old_prefix``
    so that it begins with ``new_prefix`` instead.

    Keys that do not start with ``old_prefix`` are left alone. Nothing is
    returned; ``state_dict`` itself is mutated.

    Raises:
        ValueError: if ``old_prefix`` and ``new_prefix`` are equal.

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
    """
    if new_prefix == old_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")

    prefix_len = len(old_prefix)
    # Walk a snapshot of the keys: the dict is mutated during the loop.
    for name in tuple(state_dict.keys()):
        if name.startswith(old_prefix):
            renamed = new_prefix + name[prefix_len:]
            state_dict[renamed] = state_dict.pop(name)