pytorch/torch/distributed/checkpoint/checkpointer.py
Lucas Pasqualin 5432088098 Adds Checkpointer Wrapper for DCP [3/N] (#114603)
Adds a useful high-level wrapper for calling `DCP.save`/`DCP.load` with the correct storage readers and writers.

Instead of doing:

```
DCP.save(
    state_dict={...},
    storage_writer=StorageWriter(...)
)

DCP.load(
    state_dict={...},
    storage_reader=StorageReader(...)
)
```

We can now do:

```
checkpointer = Checkpointer(...)

checkpointer.save(state_dict={...})
checkpointer.load(state_dict={...})
```
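
For instance, a minimal single-process sketch (the filesystem path, the toy model, and `no_dist=True` are illustrative assumptions; `FileSystemWriter`/`FileSystemReader` come from `torch.distributed.checkpoint`):

```
import torch.nn as nn

from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint.checkpointer import Checkpointer

# Toy model and path are placeholders for this sketch.
model = nn.Linear(4, 4)
path = "/tmp/checkpoint"

checkpointer = Checkpointer(
    storage_writer=FileSystemWriter(path),
    storage_reader=FileSystemReader(path),
    no_dist=True,  # single-process sketch: skip cross-rank synchronization
)

# Save the model's state_dict.
checkpointer.save(state_dict={"model": model.state_dict()})

# Loading is in place: pass a state_dict with matching keys/tensors,
# then hand the loaded values back to the model.
state_dict = {"model": model.state_dict()}
checkpointer.load(state_dict=state_dict)
model.load_state_dict(state_dict["model"])
```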

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114603
Approved by: https://github.com/fegin, https://github.com/wz337
2023-12-08 01:03:21 +00:00


from typing import Any, Dict, Optional

import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
import torch.distributed.checkpoint.state_dict_saver as saver
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import (
    LoadPlanner,
    SavePlanner,
    StorageReader,
    StorageWriter,
)

__all__ = ["Checkpointer"]

class Checkpointer:
    """This base class specifies a high-level API for saving and loading
    distributed ``state_dict`` s. It provides an abstraction over the low-level APIs
    provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling
    :py:meth:`torch.distributed.state_dict_saver.save` and
    :py:meth:`torch.distributed.state_dict_loader.load` with the provided storage
    readers and writers.
    """

    def __init__(
        self,
        storage_writer: StorageWriter,
        storage_reader: StorageReader,
        *,
        process_group: Optional[dist.ProcessGroup] = None,
        coordinator_rank: int = 0,
        no_dist: bool = False,
        load_planner: Optional[LoadPlanner] = None,
        save_planner: Optional[SavePlanner] = None,
    ):
        """Initializes the Checkpointer instance.

        Args:
            storage_writer: Instance of StorageWriter used to perform writes.
            storage_reader: Instance of StorageReader used to perform reads.
            process_group: ProcessGroup to be used for cross-rank synchronization.
            coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default.
            no_dist: If ``True``, distributed checkpoint will not be saved or loaded in SPMD style. (Default: ``False``)
            load_planner: Instance of LoadPlanner to use when loading.
            save_planner: Instance of SavePlanner to use when saving.
        """
        self.storage_writer = storage_writer
        self.storage_reader = storage_reader
        self.process_group = process_group
        self.coordinator_rank = coordinator_rank
        self.no_dist = no_dist
        self.load_planner = load_planner
        self.save_planner = save_planner

    def save(
        self,
        state_dict: STATE_DICT_TYPE,
    ):
        """Calls :py:meth:`torch.distributed.state_dict_saver.save`, utilizing the values passed during initialization."""
        saver.save(
            state_dict,
            self.storage_writer,
            process_group=self.process_group,
            coordinator_rank=self.coordinator_rank,
            no_dist=self.no_dist,
            planner=self.save_planner,
        )

    def load(self, state_dict: Dict[str, Any]):
        """Calls :py:meth:`torch.distributed.state_dict_loader.load`, utilizing the values passed during initialization."""
        loader.load(
            state_dict,
            storage_reader=self.storage_reader,
            process_group=self.process_group,
            coordinator_rank=self.coordinator_rank,
            no_dist=self.no_dist,
            planner=self.load_planner,
        )