mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
1. Implement the framework to allow users to choose among `state_dict`, `local_state_dict`, and `sharded_state_dict`. 2. Implement ShardedTensor-compatible `local_state_dict()` and `load_local_state_dict()`. ghstack-source-id: 149625958 Differential Revision: [D34383925](https://our.internmc.facebook.com/intern/diff/D34383925/) [ghstack-poisoned]
51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
from typing import Dict, List, Tuple, Union, Any, Callable, Set, TYPE_CHECKING
|
|
|
|
import torch
|
|
|
|
if TYPE_CHECKING:
|
|
from collections import OrderedDict # noqa: F401
|
|
|
|
"""Useful functions to deal with tensor types with other python container types."""
|
|
|
|
|
|
def _apply_to_tensors(
    fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]
) -> Any:
    """Recursively apply ``fn`` to every tensor found in ``container``.

    Dicts, lists, tuples (including namedtuples) and sets are traversed
    recursively and rebuilt from the transformed elements; any other
    non-tensor value is returned unchanged.

    Args:
        fn: Callable applied to each tensor; its return value replaces the
            tensor in the output structure.
        container: A tensor, or an arbitrarily nested dict/list/tuple/set
            that may contain tensors.

    Returns:
        A structure mirroring ``container`` with ``fn`` applied to each
        tensor. ``OrderedDict`` inputs yield an ``OrderedDict``, other dict
        types yield a plain ``dict``, and namedtuples keep their type.
    """
    import collections

    def apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
        if torch.is_tensor(x):
            return fn(x)
        if isinstance(x, collections.OrderedDict):
            # Preserve the OrderedDict type so callers that check for it
            # (e.g. state_dict consumers) keep working.
            return collections.OrderedDict(
                (key, apply(value)) for key, value in x.items()
            )
        if isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            # namedtuple constructors take their fields positionally, so the
            # generic ``type(x)(iterable)`` form below would raise TypeError.
            return type(x)(*(apply(el) for el in x))
        if isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        return x

    return apply(container)
|
|
|
|
|
|
def _replace_by_prefix(
    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
    old_prefix: str,
    new_prefix: str,
) -> None:
    """
    Replace all keys that match a given old_prefix with a new_prefix (in-place).

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}

    Renamed entries are moved to the end of ``state_dict`` in their original
    relative order; keys that do not start with ``old_prefix`` keep their
    positions. If a renamed key equals a key that did not start with
    ``old_prefix``, that entry's value is overwritten.

    Args:
        state_dict: Mapping whose keys are renamed in place.
        old_prefix: Prefix to look for at the start of each key.
        new_prefix: Prefix that replaces ``old_prefix``.

    Raises:
        ValueError: If ``old_prefix`` and ``new_prefix`` are equal.
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")
    # Detach every matching entry before inserting any renamed one. Renaming
    # one key at a time would clobber (and then rename again) existing keys
    # when new_prefix itself starts with old_prefix, e.g. "" -> "module.".
    renamed = {}
    for key in list(state_dict.keys()):
        if key.startswith(old_prefix):
            renamed[new_prefix + key[len(old_prefix) :]] = state_dict.pop(key)
    state_dict.update(renamed)
|