pytorch/torch/distributed/checkpoint/_experimental/checkpointer.py
Teja dd3e7170c2 Add async checkpointing impl to experimental checkpointer and add a builder API (#156927)
1. Adds an AsyncCheckpointer with out-of-process checkpointing and state_dict_stager with shared memory, pinned memory and Zero Overhead Support.

2. Adds two convenient functions to create sync/async checkpointers

Differential Revision: [D77336833](https://our.internmc.facebook.com/intern/diff/D77336833/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156927
Approved by: https://github.com/pradeepfn
2025-07-03 22:49:20 +00:00

342 lines
12 KiB
Python

import abc
import logging
from concurrent.futures import Future
from typing import Any, Optional, TypeVar
from .checkpoint_process import CheckpointProcess
from .checkpoint_reader import CheckpointReader
from .checkpoint_writer import CheckpointWriter
from .staging import CheckpointStager
from .types import STATE_DICT
from .utils import wrap_future
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): not referenced in the visible part of this file; presumably a
# logging interval in seconds -- confirm the unit against its users.
LOG_INTERVAL = 60
# Generic type variable; not referenced in the visible part of this file.
T = TypeVar("T")
class Checkpointer(abc.ABC):
    """
    WARNING: This class is experimental, and is created to validate certain ideas.
    It is subject to change or deprecation, and we strongly discourage any usage
    at this time.

    Abstract base class that defines the API for checkpointing.

    This class defines the interface for coordinating the writing and loading of
    model state dictionaries to and from storage. It declares three abstract
    methods -- ``save``, ``load`` and ``close`` -- and contains no implementation
    of its own.

    Concrete implementations of this class must implement all the abstract methods.
    """

    @abc.abstractmethod
    def save(
        self,
        state_dict: STATE_DICT,
        path: str,
        # The annotation on ``**kwargs`` describes each keyword *value*, so it is
        # ``Any`` rather than ``dict[str, Any]``.
        **kwargs: Any,
    ) -> Optional[tuple[Future, Future]]:
        """
        Save a state dictionary to storage.

        Args:
            state_dict: The state dictionary to save.
            path: The path where the checkpoint should be saved.
            **kwargs: Additional keyword arguments to pass to the writer.

        Returns:
            For synchronous implementations: None
            For asynchronous implementations: tuple of (stage_future, write_future)
            representing the staging and writing operations.
        """

    @abc.abstractmethod
    def load(
        self,
        path: str,
        state_dict: Optional[STATE_DICT] = None,
        *,
        default_map_location: Any = None,
        strict: bool = False,
        **kwargs: Any,
    ) -> STATE_DICT:
        """
        Load a state dictionary from storage.

        Args:
            path: The path from which to load the checkpoint.
            state_dict: Optional state dictionary to update with loaded values.
                If provided, only keys in this dictionary will be loaded.
            default_map_location: Device mapping function or device name for
                relocating tensors.
            strict: If True, raises an error when there are missing keys in the
                checkpoint.
            **kwargs: Additional keyword arguments to pass to the reader.

        Returns:
            The loaded state dictionary.
        """

    @abc.abstractmethod
    def close(self) -> None:
        """
        Close the checkpointer and release any resources.

        This method should be called when the checkpointer is no longer needed to
        ensure proper cleanup of resources.
        """
class SyncCheckpointer(Checkpointer):
    """
    Synchronous implementation of Checkpointer.

    Saving delegates directly to the writer and loading delegates directly to the
    reader; both calls block until the underlying operation returns, and ``save``
    never returns futures.

    Attributes:
        _writer: CheckpointWriter for writing state dictionaries to storage.
        _reader: CheckpointReader for reading state dictionaries from storage.

    Example:
        checkpointer = SyncCheckpointer(writer=writer, reader=reader)
        checkpointer.save(state_dict, path)
        loaded_state_dict = checkpointer.load(path)
    """

    def __init__(
        self,
        writer: CheckpointWriter,
        reader: CheckpointReader,
    ) -> None:
        """
        Initialize a synchronous checkpointer.

        Args:
            writer: CheckpointWriter for writing checkpoints to storage.
            reader: CheckpointReader for reading checkpoints from storage.
        """
        self._writer = writer
        self._reader = reader

    def save(
        self,
        state_dict: STATE_DICT,
        path: str,
        # Each keyword value may be anything; ``dict[str, Any]`` would wrongly
        # require every value to be a dict.
        **kwargs: Any,
    ) -> Optional[tuple[Future, Future]]:
        """
        Save a state dictionary to storage synchronously.

        Args:
            state_dict: The state dictionary to save.
            path: The path where the checkpoint should be saved.
            **kwargs: Additional keyword arguments forwarded to the writer's
                ``write`` call.

        Returns:
            Always None; the write has already been issued through the writer
            when this method returns.

        Example:
            checkpointer.save(state_dict, "/path/to/checkpoint")
        """
        logger.debug("Saving checkpoint synchronously to %s", path)
        self._writer.write(state_dict, path, **kwargs)
        return None

    def load(
        self,
        path: str,
        state_dict: Optional[STATE_DICT] = None,
        *,
        default_map_location: Any = None,
        strict: bool = False,
        **kwargs: Any,
    ) -> STATE_DICT:
        """
        Load a state dictionary from storage.

        Args:
            path: The path from which to load the checkpoint.
            state_dict: Optional state dictionary to update with loaded values.
                If provided, only keys in this dictionary will be loaded.
            default_map_location: Device mapping function or device name for
                relocating tensors; passed to the reader as ``map_location``.
            strict: If True, raises an error when the reader reports missing keys.
            **kwargs: Additional keyword arguments forwarded to the reader's
                ``read`` call.

        Returns:
            The loaded state dictionary, as returned by the reader.

        Raises:
            RuntimeError: If strict=True and the reader reports missing keys.
            FileNotFoundError: Presumably raised by the reader when the checkpoint
                file is missing; it is not raised directly by this method.
        """
        logger.info("Loading checkpoint from %s", path)
        loaded_state_dict, missing_keys = self._reader.read(
            path=path,
            state_dict=state_dict,
            map_location=default_map_location,
            **kwargs,
        )
        # A None or empty collection of missing keys means nothing is missing.
        if strict and missing_keys:
            raise RuntimeError(f"Checkpoint at {path} is missing keys: {missing_keys}")
        return loaded_state_dict

    def close(self) -> None:
        """
        Close the checkpointer and release any resources.

        Only the writer is closed here; this method does not call anything on the
        reader. It should be called when the checkpointer is no longer needed.
        """
        self._writer.close()
        logger.info("SyncCheckpointer closed")
class AsyncCheckpointer(Checkpointer):
    """
    Asynchronous implementation of Checkpointer.

    ``save`` first stages the state dictionary through the stager and then hands
    the staging result to the checkpoint process, returning futures for both
    steps. ``load`` goes straight to the reader and is synchronous.

    Attributes:
        _reader: CheckpointReader for reading state dictionaries from storage.
        _checkpoint_stager: Stager for async operations.
        _checkpoint_process: Process for async operations.
        _write_future: Future of the most recently started write, or None if no
            save has been started yet.

    Example:
        checkpointer = AsyncCheckpointer(
            reader=reader,
            checkpoint_stager=stager,
            checkpoint_process=process
        )
        stage_future, write_future = checkpointer.save(state_dict, path)
        # ... do other work ...
        write_future.result()  # Wait for completion
    """

    def __init__(
        self,
        checkpoint_stager: CheckpointStager,
        checkpoint_process: CheckpointProcess,
        reader: CheckpointReader,
    ) -> None:
        """
        Initialize an asynchronous checkpointer.

        Args:
            checkpoint_stager: Stager for async operations.
            checkpoint_process: Process for async operations.
            reader: CheckpointReader for reading checkpoints from storage.
        """
        self._reader = reader
        self._checkpoint_stager = checkpoint_stager
        self._checkpoint_process = checkpoint_process
        self._write_future: Optional[Future[Any]] = None

    def save(
        self,
        state_dict: STATE_DICT,
        path: str,
        **kwargs: Any,
    ) -> Optional[tuple[Future, Future]]:
        """
        Save a state dictionary to storage asynchronously.

        Before starting, this blocks on the write future of the previous save
        (if any) and re-raises its exception if that write failed.

        Args:
            state_dict: The state dictionary to save.
            path: The path where the checkpoint should be saved.
            **kwargs: Additional keyword arguments; the same set is forwarded to
                both the stager's ``stage`` call and the process's ``write`` call.

        Returns:
            A tuple of (stage_future, write_future) representing the staging and
            writing operations.

        Raises:
            RuntimeError: If the checkpoint process returns None instead of a
                write future.
            Exception: Whatever exception the previous write future holds, if the
                previous write failed.

        Example:
            stage_future, write_future = checkpointer.save(state_dict, "/path/to/checkpoint")
            # ... do other work ...
            write_future.result()  # Wait for completion
        """
        logger.info(
            "Initiating checkpoint save to %s. Will wait for prev checkpoints to complete.",
            path,
        )

        # Wait for the previous checkpoint write to finish and verify it succeeded.
        # NOTE(review): if that write failed, result() raises here before
        # _write_future is replaced, so the failed future stays stored and every
        # later save() raises the same error -- confirm this is intended.
        if self._write_future is not None:
            self._write_future.result()

        logger.debug("Starting state dictionary staging")
        staging_result = self._checkpoint_stager.stage(
            state_dict=state_dict,
            **kwargs,
        )

        logger.debug("Starting checkpoint write to %s", path)
        self._write_future = self._checkpoint_process.write(
            staging_result, path, **kwargs
        )
        logger.info("Checkpoint save to %s initiated", path)

        # Guard against the process handing back no future.
        if self._write_future is None:
            raise RuntimeError("Write future is unexpectedly None")

        # Return futures for the staging and writing operations.
        return wrap_future(staging_result), self._write_future

    def load(
        self,
        path: str,
        state_dict: Optional[STATE_DICT] = None,
        *,
        default_map_location: Any = None,
        strict: bool = False,
        **kwargs: Any,
    ) -> STATE_DICT:
        """
        Load a state dictionary from storage.

        Loading is always performed synchronously, even in AsyncCheckpointer.

        Args:
            path: The path from which to load the checkpoint.
            state_dict: Optional state dictionary to update with loaded values.
                If provided, only keys in this dictionary will be loaded.
            default_map_location: Device mapping function or device name for
                relocating tensors; passed to the reader as ``map_location``.
            strict: If True, raises an error when the reader reports missing keys.
            **kwargs: Additional keyword arguments forwarded to the reader's
                ``read`` call.

        Returns:
            The loaded state dictionary, as returned by the reader.

        Raises:
            RuntimeError: If strict=True and the reader reports missing keys.
            FileNotFoundError: Presumably raised by the reader when the checkpoint
                file is missing; it is not raised directly by this method.
        """
        logger.info("Loading checkpoint from %s", path)
        loaded_state_dict, missing_keys = self._reader.read(
            path=path,
            state_dict=state_dict,
            map_location=default_map_location,
            **kwargs,
        )
        # A None or empty collection of missing keys means nothing is missing.
        if strict and missing_keys:
            raise RuntimeError(f"Checkpoint at {path} is missing keys: {missing_keys}")
        return loaded_state_dict

    def close(self) -> None:
        """
        Close the checkpointer and release any resources.

        Closes the stager first and then the checkpoint process. The process is
        closed even when closing the stager raises; the exception still
        propagates to the caller afterwards.
        """
        try:
            self._checkpoint_stager.close()
        finally:
            # Never leave the checkpoint process open because the stager failed.
            self._checkpoint_process.close()
        logger.info("AsyncCheckpointer closed")