import os
from typing import Union

import torch
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

__all__ = ["dcp_to_torch_save"]


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds the state_dict from the saved metadata.
    Useful for loading a state_dict without first initializing a model, such as
    when converting a DCP checkpoint into a Torch save file.

    .. note:: `state_dict` must be an empty dictionary when used with this LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, it's recommended to only use
        this LoadPlanner on a single rank or process to avoid OOM.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        assert not state_dict

        # rebuild the state dict from the metadata
        for k, v in metadata.state_dict_metadata.items():
            if isinstance(v, TensorStorageMetadata):
                # allocate a placeholder tensor with the saved shape/dtype;
                # its contents are filled in during the actual load
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            # planner_data maps flattened keys back to their original nested paths
            if k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)


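# Illustrative sketch (not part of the original file): the loop in
# ``_EmptyStateDictLoadPlanner.set_up_planner`` relies on ``set_element`` to
# rebuild nested containers from the flattened paths stored in
# ``metadata.planner_data``. Roughly, with hypothetical keys and values:
#
#     nested: STATE_DICT_TYPE = {}
#     set_element(nested, ("optim", "param_groups", 0), {"lr": 0.1})
#     # nested == {"optim": {"param_groups": [{"lr": 0.1}]}}
#
# so flattened checkpoint entries land back in their original nested layout.

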
def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
):
    """
    Given a directory containing a DCP checkpoint, this function converts it into a
    Torch save file.

    Args:
        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
        torch_save_path: Filename to store the converted Torch save file.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    sd: STATE_DICT_TYPE = {}
    storage_reader = FileSystemReader(dcp_checkpoint_dir)

    # Materialize the full state dict in memory on this process, without
    # requiring a process group (no_dist=True), then write it with torch.save.
    _load_state_dict(
        sd,
        storage_reader=storage_reader,
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    torch.save(sd, torch_save_path)
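

# Usage sketch (illustrative, not part of the original file). The checkpoint
# directory and output path below are hypothetical; a DCP checkpoint is
# assumed to already exist at "checkpoint/".
if __name__ == "__main__":
    dcp_to_torch_save("checkpoint/", "model.pt")
    # The result is an ordinary torch.save file that can be loaded anywhere.
    state_dict = torch.load("model.pt")
    print(sorted(state_dict.keys()))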