[DCP] Introduce process based async checkpointing (#147039)

Summary:
### Context
The background checkpoint upload thread interferes with the trainer thread:

In the [async save API](https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py#L239-L248), the background thread spends a considerable amount of time on CPU-bound tasks (pickling/unpickling several metadata objects, a.k.a. SavePlans) on rank0 during the collective operation. This asymmetric computation contends heavily with the trainer thread for the GIL, causing GPU utilization to suffer significantly for the entire E2E checkpoint duration.
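For reference, the thread-based path that exhibits this contention boils down to the following sketch (it mirrors the lines removed from state_dict_saver.py below; the wrapper function name is illustrative):

    # Sketch of the prior thread-based async save path. The CPU-bound
    # SavePlan pickling inside `save` runs under the same GIL as the
    # trainer thread.
    from concurrent.futures import Future, ThreadPoolExecutor

    from torch.distributed.checkpoint.state_dict_saver import save

    def _thread_based_async_save(
        staged_state_dict, checkpoint_id, storage_writer, planner, process_group
    ) -> Future:
        executor = ThreadPoolExecutor(max_workers=1)
        f: Future = executor.submit(
            save,  # rank0's plan pickling/unpickling happens here
            staged_state_dict,
            checkpoint_id=checkpoint_id,
            storage_writer=storage_writer,
            planner=planner,
            process_group=process_group,
        )
        f.add_done_callback(lambda f: executor.shutdown(wait=False))
        return f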

### Solution:
Introduce async save via a checkpoint daemon process. The daemon process is created once (during the first save attempt) and serves async checkpoint requests for the remainder of the training lifetime.
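A minimal usage sketch (assumes an initialized default process group; `state_dict` and the checkpoint path are illustrative; the new `async_checkpointer_type` argument is the only API change):

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType

    # The first PROCESS-type call spawns the checkpoint daemon process;
    # subsequent calls reuse it for the rest of training.
    future = dcp.async_save(
        state_dict,  # assumed: the trainer's state dict
        checkpoint_id="/tmp/checkpoint",  # illustrative path
        async_checkpointer_type=AsyncCheckpointerType.PROCESS,
    )
    # ... continue training ...
    future.result()  # block only when the save must be complete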

Test Plan: Added E2E UTs for process-based async save.

Differential Revision: D69272583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147039
Approved by: https://github.com/saumishr
Meet Vadakkanchery, 2025-03-04 13:33:28 +00:00; committed by PyTorch MergeBot
parent 16d07988fc
commit fdee60769a
7 changed files with 448 additions and 10 deletions

docs/source/distributed.checkpoint.rst

@@ -27,6 +27,9 @@ Additional resources:
.. currentmodule:: torch.distributed.checkpoint.state_dict_saver
.. autoclass:: torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType
:members:
.. autofunction:: save
.. autofunction:: async_save
.. autofunction:: save_state_dict

test/distributed/checkpoint/e2e/test_e2e_save_and_load.py

@@ -23,6 +23,7 @@ from torch.distributed.checkpoint.state_dict import (
set_state_dict,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.utils import CheckpointException
from torch.distributed.distributed_c10d import ReduceOp
@@ -214,17 +215,31 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
@parametrize("cache_staged_state_dict", [False, True])
def test_e2e_async_cached(self, cache_staged_state_dict):
@parametrize(
"cache_staged_state_dict, async_checkpointer_type",
[
(False, AsyncCheckpointerType.THREAD),
(True, AsyncCheckpointerType.THREAD),
(False, AsyncCheckpointerType.PROCESS),
(True, AsyncCheckpointerType.PROCESS),
],
)
def test_e2e_async_cached(self, cache_staged_state_dict, async_checkpointer_type):
self._run_e2e_test(
compile=False,
model_type=ModelType.FSDP,
async_op=True,
cache_staged_state_dict=cache_staged_state_dict,
+ async_checkpointer_type=async_checkpointer_type,
)
def _run_e2e_test(
- self, compile, model_type, async_op=False, cache_staged_state_dict=False
+ self,
+ compile,
+ model_type,
+ async_op=False,
+ cache_staged_state_dict=False,
+ async_checkpointer_type=None,
):
model, optim = self._create_model(compile, ModelType.NONE)
_train(model, optim, train_steps=2)
@@ -244,7 +259,13 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
writer = DCP.FileSystemWriter(
self.temp_dir, cache_staged_state_dict=cache_staged_state_dict
)
- f = saver.async_save(sd, storage_writer=writer)
+ f = saver.async_save(
+     sd,
+     storage_writer=writer,
+     async_checkpointer_type=async_checkpointer_type
+     if async_checkpointer_type
+     else AsyncCheckpointerType.THREAD,
+ )
t = time.monotonic()
while not f.done():
time.sleep(1)

torch/distributed/checkpoint/_async_executor.py

@@ -0,0 +1,32 @@
# pyre-strict
# mypy: allow-untyped-defs
import abc
import os
from concurrent.futures import Future
from typing import Optional, Union
import torch.distributed as dist
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter
class _AsyncCheckpointExecutor(abc.ABC):
@abc.abstractmethod
def execute_save(
self,
staged_state_dict: STATE_DICT_TYPE,
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Future:
"""
Execute the checkpoint save request asynchronously.
This method is intended to be used as an abstraction for
implementing async checkpointing. The actual checkpoint save
operation is executed in a separate thread or process depending
on the implementation of this interface.
"""

torch/distributed/checkpoint/_async_process_executor.py

@@ -0,0 +1,307 @@
# pyre-strict
# mypy: allow-untyped-defs
import logging
import os
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Union
from uuid import uuid4
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor
from torch.distributed.checkpoint.logger import _dcp_method_logger, _init_logger
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.checkpoint.utils import _DistWrapper
from torch.distributed.elastic.agent.server.api import _get_fq_hostname
from torch.distributed.elastic.utils.distributed import get_free_port
logger = logging.getLogger()
class _CheckpointSaveProcessControlOpts(Enum):
INIT_COMPLETE = "init_complete"
TERMINATE = "terminate"
@dataclass(init=False, unsafe_hash=True)
class _CheckpointRequestIdentifier:
checkpoint_id: Union[str, os.PathLike, None]
uuid: str
def __init__(self, checkpoint_id: Union[str, os.PathLike, None]):
self.checkpoint_id = checkpoint_id
self.uuid = str(uuid4())
@dataclass
class _AsyncCheckpointRequest:
staged_state_dict: STATE_DICT_TYPE
checkpoint_request_id: _CheckpointRequestIdentifier
storage_writer: Optional[StorageWriter] = None
planner: Optional[SavePlanner] = None
@dataclass(init=False)
class _ProcessGroupInitInfo:
local_rank: int
global_rank: int
world_size: int
tcp_store_master_addr: str
tcp_store_master_port: int
def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
self.local_rank = dist.get_node_local_rank(fallback_rank=0)
self.global_rank = dist.get_rank(process_group)
self.world_size = dist.get_world_size(process_group)
# Let coordinator rank find a free port on the localhost.
# Broadcast the (master_addr, free_port) to all ranks; each rank in the
# checkpoint daemon process will use TCPStore (master_addr, master_port)
# for collective communication.
dist_wrapper: _DistWrapper = _DistWrapper(
group=process_group,
use_dist=True,
coordinator_rank=0,
)
def get_master_addr_and_port() -> tuple[str, int]:
master_addr = os.environ.get("MASTER_ADDR")
if master_addr is None:
master_addr = _get_fq_hostname()
return master_addr, get_free_port()
self.tcp_store_master_addr, self.tcp_store_master_port = dist_wrapper.broadcast(
step="get_master_addr_and_port",
map_fun=get_master_addr_and_port,
)
class _AsyncCheckpointProcess:
def __init__(
self,
pg_init_info: _ProcessGroupInitInfo,
):
self.ctx = mp.get_context("spawn")
self._mp_queue_send: mp.Queue = self.ctx.Queue()
self._mp_queue_recv: mp.Queue = self.ctx.Queue()
self._save_process = self.ctx.Process(
target=self._checkpointing_subprocess,
args=(
pg_init_info,
self._mp_queue_send,
self._mp_queue_recv,
),
daemon=True,
)
self._save_process.start()
response = self._wait_for_response()
assert response == _CheckpointSaveProcessControlOpts.INIT_COMPLETE
def __del__(self) -> None:
if self._save_process.is_alive():
logger.info("Terminating the checkpoint background process...")
self._mp_queue_send.put(_CheckpointSaveProcessControlOpts.TERMINATE)
self._save_process.join()
def save(
self,
staged_state_dict: STATE_DICT_TYPE,
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
) -> Metadata:
# Create a unique identifier to locate requests/responses
# from the checkpoint daemon process.
checkpoint_request_id = _CheckpointRequestIdentifier(checkpoint_id)
async_cp_request = _AsyncCheckpointRequest(
staged_state_dict=staged_state_dict,
checkpoint_request_id=checkpoint_request_id,
storage_writer=storage_writer,
planner=planner,
)
self._mp_queue_send.put(async_cp_request)
result = self._wait_for_response()
assert isinstance(result, Metadata)
return result
def _wait_for_response(self) -> Any:
if not self._save_process.is_alive():
logger.info("Checkpoint background process is dead calling join()...")
self._save_process.join()
raise RuntimeError("Checkpoint background process is dead.")
response = self._mp_queue_recv.get()
if isinstance(response, BaseException):
raise response
return response
@staticmethod
def _execute_save(
state_dict: STATE_DICT_TYPE,
*,
checkpoint_request_id: _CheckpointRequestIdentifier,
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
) -> Metadata:
from torch.distributed.checkpoint.state_dict_saver import save
metadata = save(
state_dict,
checkpoint_id=checkpoint_request_id.checkpoint_id,
storage_writer=storage_writer,
planner=planner,
)
return metadata
@staticmethod
def _checkpointing_subprocess(
pg_init_info: _ProcessGroupInitInfo,
recv: mp.Queue,
send: mp.Queue,
) -> None:
try:
_init_logger(pg_init_info.global_rank)
# Setup environment variables for process group initialization.
os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
os.environ["MASTER_ADDR"] = pg_init_info.tcp_store_master_addr
os.environ["MASTER_PORT"] = str(pg_init_info.tcp_store_master_port)
os.environ["LOCAL_RANK"] = str(pg_init_info.local_rank)
os.environ["RANK"] = str(pg_init_info.global_rank)
os.environ["WORLD_SIZE"] = str(pg_init_info.world_size)
logger.info(
"Initializing dist.ProcessGroup in checkpoint background process"
)
# NOTE: GLOO backend is enforced here.
dist.init_process_group(backend=dist.Backend.GLOO)
dist.barrier()
logger.info("Checkpoint background process is running...")
send.put(_CheckpointSaveProcessControlOpts.INIT_COMPLETE)
# Serving loop.
while True:
logger.info("Waiting for checkpoint save request...")
obj = recv.get()
if (
isinstance(obj, _CheckpointSaveProcessControlOpts)
and obj == _CheckpointSaveProcessControlOpts.TERMINATE
):
logger.info("Terminating the checkpoint background process.")
return
assert isinstance(obj, _AsyncCheckpointRequest)
logger.info(
f"Received async checkpoint request with id={obj.checkpoint_request_id.checkpoint_id}" # noqa: G004
)
response = _AsyncCheckpointProcess._execute_save(
obj.staged_state_dict,
checkpoint_request_id=obj.checkpoint_request_id,
storage_writer=obj.storage_writer,
planner=obj.planner,
)
send.put(response)
logger.info(
f"Submitted checkpoint save request for checkpoint_id={obj.checkpoint_request_id}" # noqa: G004
)
except BaseException as e:
logger.error(
f"Checkpoint background process encountered an exception: {e}" # noqa: G004
)
send.put(e)
raise
finally:
logger.info("Checkpoint background process is shutting down...")
dist.destroy_process_group()
_CHECKPOINT_PROCESS: Optional[_AsyncCheckpointProcess] = None
class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
def __init__(self) -> None:
self._executor = ThreadPoolExecutor(max_workers=1)
@staticmethod
def _execute_save_impl(
*,
pg_init_info: Optional[_ProcessGroupInitInfo],
staged_state_dict: STATE_DICT_TYPE,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Metadata:
global _CHECKPOINT_PROCESS
if _CHECKPOINT_PROCESS is None:
assert pg_init_info is not None
ckpt_kwargs = {}
if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
ckpt_kwargs["checkpoint_id"] = ckpt_id
ckpt_kwargs["process_group"] = process_group
@_dcp_method_logger(**ckpt_kwargs)
def create_checkpoint_daemon_process() -> None:
global _CHECKPOINT_PROCESS
_CHECKPOINT_PROCESS = _AsyncCheckpointProcess(pg_init_info=pg_init_info)
create_checkpoint_daemon_process()
assert _CHECKPOINT_PROCESS is not None
return _CHECKPOINT_PROCESS.save(
staged_state_dict=staged_state_dict,
checkpoint_id=checkpoint_id,
storage_writer=storage_writer,
planner=planner,
)
def execute_save(
self,
staged_state_dict: STATE_DICT_TYPE,
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Future:
"""
NOTE:
- Checkpoint process is implemented as a daemon process.
The AsyncCheckpointProcess' lifetime is tied to the lifetime of the
main process (e.g. trainer process).
- The first call to execute_save_in_process() will initialize the checkpoint
daemon process. Subsequent async checkpoint requests will not need process
initialization. Therefore, the first async checkpoint request will take longer to complete.
- Process initialization can have significant overhead, dominated by latency for all ranks to spawn
a background process + process group initialization in the background process.
"""
global _CHECKPOINT_PROCESS
pg_init_info: Optional[_ProcessGroupInitInfo] = None
if _CHECKPOINT_PROCESS is None:
# Find a free port on coordinator rank and broadcast
# to all ranks.
pg_init_info = _ProcessGroupInitInfo(process_group)
f: Future = self._executor.submit(
self._execute_save_impl,
pg_init_info=pg_init_info,
staged_state_dict=staged_state_dict,
checkpoint_id=checkpoint_id,
storage_writer=storage_writer,
planner=planner,
)
f.add_done_callback(lambda f: self._executor.shutdown(wait=False))
return f
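Aside: to make the NOTE in execute_save concrete, a hedged sketch of what a caller might observe (assumes an initialized default process group; `sd` and the paths are illustrative):

    import time

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType

    # The first request pays for daemon spawn + GLOO process group init in
    # the background process; later requests reuse the warm daemon.
    for step in range(3):
        t = time.monotonic()
        f = dcp.async_save(
            sd,  # assumed: a state dict prepared by the caller
            checkpoint_id=f"/tmp/ckpt-{step}",  # illustrative path
            async_checkpointer_type=AsyncCheckpointerType.PROCESS,
        )
        f.result()
        print(f"save {step}: {time.monotonic() - t:.1f}s")  # step 0 is slowest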

torch/distributed/checkpoint/_async_thread_executor.py

@@ -0,0 +1,39 @@
# pyre-strict
# mypy: allow-untyped-defs
import os
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional, Union
import torch.distributed as dist
from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter
class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
def __init__(self) -> None:
self._executor = ThreadPoolExecutor(max_workers=1)
def execute_save(
self,
staged_state_dict: STATE_DICT_TYPE,
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Future:
from torch.distributed.checkpoint.state_dict_saver import save
f: Future = self._executor.submit(
save,
staged_state_dict,
checkpoint_id=checkpoint_id,
storage_writer=storage_writer,
planner=planner,
process_group=process_group,
)
f.add_done_callback(lambda f: self._executor.shutdown(wait=False))
return f

torch/distributed/checkpoint/logger.py

@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import functools
import logging
import time
from typing import Any, Callable, TypeVar
from typing_extensions import ParamSpec
@@ -9,6 +10,9 @@ import torch.distributed.c10d_logger as c10d_logger
from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
logger = logging.getLogger()
__all__: list[str] = []
global _dcp_logger
@@ -101,3 +105,14 @@ def _dcp_method_logger(
return wrapper
return decorator
def _init_logger(rank: int):
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
f"[{rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
ch.setFormatter(formatter)
logger.addHandler(ch)

torch/distributed/checkpoint/state_dict_saver.py

@@ -3,13 +3,23 @@
import inspect
import os
import warnings
- from concurrent.futures import Future, ThreadPoolExecutor
+ from concurrent.futures import Future
from enum import Enum
from typing import cast, Optional, Union
from typing_extensions import deprecated
import torch
import torch.distributed as dist
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint._async_executor import ( # noqa: TC001
_AsyncCheckpointExecutor,
)
from torch.distributed.checkpoint._async_process_executor import (
_ProcessBasedAsyncCheckpointExecutor,
)
from torch.distributed.checkpoint._async_thread_executor import (
_ThreadBasedAsyncCheckpointExecutor,
)
from torch.distributed.checkpoint._storage_utils import _storage_setup
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
@@ -23,7 +33,14 @@ from torch.distributed.distributed_c10d import _get_default_group
from .utils import _api_bc_check, _DistWrapper, _profile
__all__ = ["save_state_dict", "save", "async_save"]
__all__ = ["save_state_dict", "save", "async_save", "AsyncCheckpointerType"]
class AsyncCheckpointerType(Enum):
"""Enum for async checkpointer type."""
THREAD = "thread"
PROCESS = "process"
@deprecated(
@@ -173,6 +190,7 @@ def async_save(
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
async_checkpointer_type: AsyncCheckpointerType = AsyncCheckpointerType.THREAD,
) -> Future:
"""Asynchronous version of ``save``. This code first de-stages the state_dict on to the
staging storage (defaults to CPU memory), and then calls the `save` in a separate thread.
@@ -242,16 +260,19 @@ def async_save(
staged_state_dict = _create_cpu_state_dict(state_dict)
_copy_state_dict(state_dict, staged_state_dict, type_check=False)
- executor = ThreadPoolExecutor(max_workers=1)
- f: Future = executor.submit(
-     save,
+ executor: _AsyncCheckpointExecutor = (
+     _ProcessBasedAsyncCheckpointExecutor()
+     if async_checkpointer_type == AsyncCheckpointerType.PROCESS
+     else _ThreadBasedAsyncCheckpointExecutor()
+ )
+ f: Future = executor.execute_save(
      staged_state_dict,
      checkpoint_id=checkpoint_id,
      storage_writer=storage_writer,
      planner=planner,
      process_group=process_group,
  )
- f.add_done_callback(lambda f: executor.shutdown(wait=False))
if (
isinstance(storage_writer, AsyncStager)