import os
from typing import Union

import torch
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

__all__ = ["dcp_to_torch_save"]


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
    Useful for loading in a state_dict without first initializing a model, such as
    when converting a DCP checkpoint into a Torch save file.

    N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, it's recommended to only utilize
        this LoadPlanner on a single rank or process to avoid OOM.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        assert not state_dict

        # Rebuild the state dict from the metadata: allocate an empty tensor for
        # each tensor entry, and place it at its original (possibly nested)
        # location using the traversal path recorded in planner_data.
        for k, v in metadata.state_dict_metadata.items():
            if isinstance(v, TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)


def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
):
    """
    Given a directory containing a DCP checkpoint, this function will convert it
    into a Torch save file.

    Args:
        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
        torch_save_path: Filename to store the converted Torch save file.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    sd: STATE_DICT_TYPE = {}
    storage_reader = FileSystemReader(dcp_checkpoint_dir)

    # Load the full checkpoint into the empty state dict on this process alone;
    # no_dist=True avoids requiring an initialized process group.
    _load_state_dict(
        sd,
        storage_reader=storage_reader,
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    torch.save(sd, torch_save_path)
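

if __name__ == "__main__":
    # A minimal usage sketch (an assumption, not part of the library API): expose
    # the conversion as a small CLI. The checkpoint paths are supplied by the
    # caller and must point at an existing DCP checkpoint directory. The resulting
    # file can then be loaded with a plain torch.load(torch_save_path).
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a DCP checkpoint directory into a single torch.save file."
    )
    parser.add_argument(
        "dcp_checkpoint_dir", help="Directory containing the DCP checkpoint."
    )
    parser.add_argument(
        "torch_save_path", help="Destination filename for the converted Torch save file."
    )
    args = parser.parse_args()

    # Run on a single process to avoid OOM, per the warning above.
    dcp_to_torch_save(args.dcp_checkpoint_dir, args.torch_save_path)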