From fdee60769ac0d4668334b018f32c48a44141bed5 Mon Sep 17 00:00:00 2001
From: Meet Vadakkanchery
Date: Tue, 4 Mar 2025 13:33:28 +0000
Subject: [PATCH] [DCP] Introduce process based async checkpointing (#147039)

Summary:

### Context
Background checkpoint upload thread interfering with the trainer thread:

In the [async save API](https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py#L239-L248), the background thread spends a considerable amount of time on CPU-bound work (pickling/unpickling several metadata objects, a.k.a. SavePlans) on rank0 during the collective operation; this asymmetric computation contends heavily with the trainer thread for the GIL, causing GPU utilization to suffer significantly for the entire checkpoint duration.

### Solution:
Introduce async save via a checkpoint daemon process. This daemon process is created once (during the first save attempt) and serves async checkpoint requests for the remainder of the training lifetime.
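
### Example usage (illustrative)
A minimal sketch of how a trainer can opt into the process-based path via the new `async_checkpointer_type` argument; `state_dict` and `checkpoint_dir` are placeholders for the caller's own objects, and a process group is assumed to be initialized already:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType

# Illustrative only: `state_dict` and `checkpoint_dir` come from the trainer,
# and torch.distributed must already be initialized on every rank.
writer = FileSystemWriter(checkpoint_dir)
future = dcp.async_save(
    state_dict,
    storage_writer=writer,
    async_checkpointer_type=AsyncCheckpointerType.PROCESS,  # default remains THREAD
)
# ... continue training; wait on the future only when the checkpoint must be durable.
metadata = future.result()
```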

Test Plan: Added E2E unit tests for process-based async save.

Differential Revision: D69272583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147039
Approved by: https://github.com/saumishr
---
 docs/source/distributed.checkpoint.rst        |   3 +
 .../checkpoint/e2e/test_e2e_save_and_load.py  |  29 +-
 .../distributed/checkpoint/_async_executor.py |  32 ++
 .../checkpoint/_async_process_executor.py     | 307 ++++++++++++++++++
 .../checkpoint/_async_thread_executor.py      |  39 +++
 torch/distributed/checkpoint/logger.py        |  15 +
 .../checkpoint/state_dict_saver.py            |  33 +-
 7 files changed, 448 insertions(+), 10 deletions(-)
 create mode 100644 torch/distributed/checkpoint/_async_executor.py
 create mode 100644 torch/distributed/checkpoint/_async_process_executor.py
 create mode 100644 torch/distributed/checkpoint/_async_thread_executor.py

diff --git a/docs/source/distributed.checkpoint.rst b/docs/source/distributed.checkpoint.rst
index fa5102063a3..21c6e99e8b8 100644
--- a/docs/source/distributed.checkpoint.rst
+++ b/docs/source/distributed.checkpoint.rst
@@ -27,6 +27,9 @@ Additional resources:
 
 .. currentmodule:: torch.distributed.checkpoint.state_dict_saver
 
+.. autoclass:: torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType
+  :members:
+
 .. autofunction:: save
 .. autofunction:: async_save
 .. autofunction:: save_state_dict
diff --git a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
index 732656d5b13..e0a4b891844 100644
--- a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
+++ b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
@@ -23,6 +23,7 @@ from torch.distributed.checkpoint.state_dict import (
     set_state_dict,
 )
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
+from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.checkpoint.utils import CheckpointException
 from torch.distributed.distributed_c10d import ReduceOp
@@ -214,17 +215,31 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
     @with_comms
     @skip_if_lt_x_gpu(4)
     @with_temp_dir
-    @parametrize("cache_staged_state_dict", [False, True])
-    def test_e2e_async_cached(self, cache_staged_state_dict):
+    @parametrize(
+        "cache_staged_state_dict, async_checkpointer_type",
+        [
+            (False, AsyncCheckpointerType.THREAD),
+            (True, AsyncCheckpointerType.THREAD),
+            (False, AsyncCheckpointerType.PROCESS),
+            (True, AsyncCheckpointerType.PROCESS),
+        ],
+    )
+    def test_e2e_async_cached(self, cache_staged_state_dict, async_checkpointer_type):
         self._run_e2e_test(
             compile=False,
             model_type=ModelType.FSDP,
             async_op=True,
             cache_staged_state_dict=cache_staged_state_dict,
+            async_checkpointer_type=async_checkpointer_type,
         )
 
     def _run_e2e_test(
-        self, compile, model_type, async_op=False, cache_staged_state_dict=False
+        self,
+        compile,
+        model_type,
+        async_op=False,
+        cache_staged_state_dict=False,
+        async_checkpointer_type=None,
     ):
         model, optim = self._create_model(compile, ModelType.NONE)
         _train(model, optim, train_steps=2)
@@ -244,7 +259,13 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
             writer = DCP.FileSystemWriter(
                 self.temp_dir, cache_staged_state_dict=cache_staged_state_dict
             )
-            f = saver.async_save(sd, storage_writer=writer)
+            f = saver.async_save(
+                sd,
+                storage_writer=writer,
+                async_checkpointer_type=async_checkpointer_type
+                if async_checkpointer_type
+                else AsyncCheckpointerType.THREAD,
+            )
             t = time.monotonic()
             while not f.done():
                 time.sleep(1)
diff --git a/torch/distributed/checkpoint/_async_executor.py b/torch/distributed/checkpoint/_async_executor.py
new file mode 100644
index 00000000000..7da04c12b4b
--- /dev/null
+++ b/torch/distributed/checkpoint/_async_executor.py
@@ -0,0 +1,32 @@
+# pyre-strict
+# mypy: allow-untyped-defs
+import abc
+import os
+from concurrent.futures import Future
+from typing import Optional, Union
+
+import torch.distributed as dist
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+from torch.distributed.checkpoint.planner import SavePlanner
+from torch.distributed.checkpoint.storage import StorageWriter
+
+
+class _AsyncCheckpointExecutor(abc.ABC):
+    @abc.abstractmethod
+    def execute_save(
+        self,
+        staged_state_dict: STATE_DICT_TYPE,
+        *,
+        checkpoint_id: Union[str, os.PathLike, None] = None,
+        storage_writer: Optional[StorageWriter] = None,
+        planner: Optional[SavePlanner] = None,
+        process_group: Optional[dist.ProcessGroup] = None,
+    ) -> Future:
+        """
+        Execute the checkpoint save request asynchronously.
+
+        This method is intended to be used as an abstraction for
+        implementing async checkpointing. The actual checkpoint save
+        operation is executed in a separate thread or process depending
+        on the implementation of this interface.
+        """
diff --git a/torch/distributed/checkpoint/_async_process_executor.py b/torch/distributed/checkpoint/_async_process_executor.py
new file mode 100644
index 00000000000..801c8d79e8d
--- /dev/null
+++ b/torch/distributed/checkpoint/_async_process_executor.py
@@ -0,0 +1,307 @@
+# pyre-strict
+# mypy: allow-untyped-defs
+import logging
+import os
+from concurrent.futures import Future, ThreadPoolExecutor
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Optional, Union
+from uuid import uuid4
+
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor
+from torch.distributed.checkpoint.logger import _dcp_method_logger, _init_logger
+from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
+from torch.distributed.checkpoint.planner import SavePlanner
+from torch.distributed.checkpoint.storage import StorageWriter
+from torch.distributed.checkpoint.utils import _DistWrapper
+from torch.distributed.elastic.agent.server.api import _get_fq_hostname
+from torch.distributed.elastic.utils.distributed import get_free_port
+
+
+logger = logging.getLogger()
+
+
+class _CheckpointSaveProcessControlOpts(Enum):
+    INIT_COMPLETE = "init_complete"
+    TERMINATE = "terminate"
+
+
+@dataclass(init=False, unsafe_hash=True)
+class _CheckpointRequestIdentifier:
+    checkpoint_id: Union[str, os.PathLike, None]
+    uuid: str
+
+    def __init__(self, checkpoint_id: Union[str, os.PathLike, None]):
+        self.checkpoint_id = checkpoint_id
+        self.uuid = str(uuid4())
+
+
+@dataclass
+class _AsyncCheckpointRequest:
+    staged_state_dict: STATE_DICT_TYPE
+    checkpoint_request_id: _CheckpointRequestIdentifier
+    storage_writer: Optional[StorageWriter] = None
+    planner: Optional[SavePlanner] = None
+
+
+@dataclass(init=False)
+class _ProcessGroupInitInfo:
+    local_rank: int
+    global_rank: int
+    world_size: int
+    tcp_store_master_addr: str
+    tcp_store_master_port: int
+
+    def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
+        self.local_rank = dist.get_node_local_rank(fallback_rank=0)
+        self.global_rank = dist.get_rank(process_group)
+        self.world_size = dist.get_world_size(process_group)
+
+        # Let the coordinator rank find a free port on localhost.
+        # Broadcast the (master_addr, free_port) to all ranks; each rank in the
+        # checkpoint daemon process will use TCPStore (master_addr, master_port)
+        # for collective communication.
+        dist_wrapper: _DistWrapper = _DistWrapper(
+            group=process_group,
+            use_dist=True,
+            coordinator_rank=0,
+        )
+
+        def get_master_addr_and_port() -> tuple[str, int]:
+            master_addr = os.environ.get("MASTER_ADDR")
+            if master_addr is None:
+                master_addr = _get_fq_hostname()
+            return master_addr, get_free_port()
+
+        self.tcp_store_master_addr, self.tcp_store_master_port = dist_wrapper.broadcast(
+            step="get_master_addr_and_port",
+            map_fun=get_master_addr_and_port,
+        )
+
+
+class _AsyncCheckpointProcess:
+    def __init__(
+        self,
+        pg_init_info: _ProcessGroupInitInfo,
+    ):
+        self.ctx = mp.get_context("spawn")
+        self._mp_queue_send: mp.Queue = self.ctx.Queue()
+        self._mp_queue_recv: mp.Queue = self.ctx.Queue()
+
+        self._save_process = self.ctx.Process(
+            target=self._checkpointing_subprocess,
+            args=(
+                pg_init_info,
+                self._mp_queue_send,
+                self._mp_queue_recv,
+            ),
+            daemon=True,
+        )
+
+        self._save_process.start()
+        response = self._wait_for_response()
+        assert response == _CheckpointSaveProcessControlOpts.INIT_COMPLETE
+
+    def __del__(self) -> None:
+        if self._save_process.is_alive():
+            logger.info("Terminating the checkpoint background process...")
+            self._mp_queue_send.put(_CheckpointSaveProcessControlOpts.TERMINATE)
+            self._save_process.join()
+
+    def save(
+        self,
+        staged_state_dict: STATE_DICT_TYPE,
+        *,
+        checkpoint_id: Union[str, os.PathLike, None] = None,
+        storage_writer: Optional[StorageWriter] = None,
+        planner: Optional[SavePlanner] = None,
+    ) -> Metadata:
+        # Create a unique identifier to locate requests/responses
+        # from the checkpoint daemon process.
+        checkpoint_request_id = _CheckpointRequestIdentifier(checkpoint_id)
+        async_cp_request = _AsyncCheckpointRequest(
+            staged_state_dict=staged_state_dict,
+            checkpoint_request_id=checkpoint_request_id,
+            storage_writer=storage_writer,
+            planner=planner,
+        )
+        self._mp_queue_send.put(async_cp_request)
+        result = self._wait_for_response()
+        assert isinstance(result, Metadata)
+        return result
+
+    def _wait_for_response(self) -> Any:
+        if not self._save_process.is_alive():
+            logger.info("Checkpoint background process is dead, calling join()...")
+            self._save_process.join()
+            raise RuntimeError("Checkpoint background process is dead.")
+        response = self._mp_queue_recv.get()
+        if isinstance(response, BaseException):
+            raise response
+        return response
+
+    @staticmethod
+    def _execute_save(
+        state_dict: STATE_DICT_TYPE,
+        *,
+        checkpoint_request_id: _CheckpointRequestIdentifier,
+        storage_writer: Optional[StorageWriter] = None,
+        planner: Optional[SavePlanner] = None,
+    ) -> Metadata:
+        from torch.distributed.checkpoint.state_dict_saver import save
+
+        metadata = save(
+            state_dict,
+            checkpoint_id=checkpoint_request_id.checkpoint_id,
+            storage_writer=storage_writer,
+            planner=planner,
+        )
+        return metadata
+
+    @staticmethod
+    def _checkpointing_subprocess(
+        pg_init_info: _ProcessGroupInitInfo,
+        recv: mp.Queue,
+        send: mp.Queue,
+    ) -> None:
+        try:
+            _init_logger(pg_init_info.global_rank)
+
+            # Set up environment variables for process group initialization.
+            os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
+            os.environ["MASTER_ADDR"] = pg_init_info.tcp_store_master_addr
+            os.environ["MASTER_PORT"] = str(pg_init_info.tcp_store_master_port)
+            os.environ["LOCAL_RANK"] = str(pg_init_info.local_rank)
+            os.environ["RANK"] = str(pg_init_info.global_rank)
+            os.environ["WORLD_SIZE"] = str(pg_init_info.world_size)
+
+            logger.info(
+                "Initializing dist.ProcessGroup in checkpoint background process"
+            )
+            # NOTE: GLOO backend is enforced here.
+            dist.init_process_group(backend=dist.Backend.GLOO)
+            dist.barrier()
+
+            logger.info("Checkpoint background process is running...")
+            send.put(_CheckpointSaveProcessControlOpts.INIT_COMPLETE)
+
+            # Serving loop.
+            while True:
+                logger.info("Waiting for checkpoint save request...")
+                obj = recv.get()
+                if (
+                    isinstance(obj, _CheckpointSaveProcessControlOpts)
+                    and obj == _CheckpointSaveProcessControlOpts.TERMINATE
+                ):
+                    logger.info("Terminating the checkpoint background process.")
+                    return
+                assert isinstance(obj, _AsyncCheckpointRequest)
+                logger.info(
+                    f"Received async checkpoint request with id={obj.checkpoint_request_id.checkpoint_id}"  # noqa: G004
+                )
+
+                response = _AsyncCheckpointProcess._execute_save(
+                    obj.staged_state_dict,
+                    checkpoint_request_id=obj.checkpoint_request_id,
+                    storage_writer=obj.storage_writer,
+                    planner=obj.planner,
+                )
+                send.put(response)
+                logger.info(
+                    f"Completed checkpoint save request for checkpoint_id={obj.checkpoint_request_id}"  # noqa: G004
+                )
+        except BaseException as e:
+            logger.error(
+                f"Checkpoint background process encountered an exception: {e}"  # noqa: G004
+            )
+            send.put(e)
+            raise
+        finally:
+            logger.info("Checkpoint background process is shutting down...")
+            dist.destroy_process_group()
+
+
+_CHECKPOINT_PROCESS: Optional[_AsyncCheckpointProcess] = None
+
+
+class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
+    def __init__(self) -> None:
+        self._executor = ThreadPoolExecutor(max_workers=1)
+
+    @staticmethod
+    def _execute_save_impl(
+        *,
+        pg_init_info: Optional[_ProcessGroupInitInfo],
+        staged_state_dict: STATE_DICT_TYPE,
+        checkpoint_id: Union[str, os.PathLike, None] = None,
+        storage_writer: Optional[StorageWriter] = None,
+        planner: Optional[SavePlanner] = None,
+        process_group: Optional[dist.ProcessGroup] = None,
+    ) -> Metadata:
+        global _CHECKPOINT_PROCESS
+        if _CHECKPOINT_PROCESS is None:
+            assert pg_init_info is not None
+            ckpt_kwargs = {}
+            if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
+                ckpt_kwargs["checkpoint_id"] = ckpt_id
+                ckpt_kwargs["process_group"] = process_group
+
+            @_dcp_method_logger(**ckpt_kwargs)
+            def create_checkpoint_daemon_process() -> None:
+                global _CHECKPOINT_PROCESS
+                _CHECKPOINT_PROCESS = _AsyncCheckpointProcess(pg_init_info=pg_init_info)
+
+            create_checkpoint_daemon_process()
+
+        assert _CHECKPOINT_PROCESS is not None
+        return _CHECKPOINT_PROCESS.save(
+            staged_state_dict=staged_state_dict,
+            checkpoint_id=checkpoint_id,
+            storage_writer=storage_writer,
+            planner=planner,
+        )
+
+    def execute_save(
+        self,
+        staged_state_dict: STATE_DICT_TYPE,
+        *,
+        checkpoint_id: Union[str, os.PathLike, None] = None,
+        storage_writer: Optional[StorageWriter] = None,
+        planner: Optional[SavePlanner] = None,
+        process_group: Optional[dist.ProcessGroup] = None,
+    ) -> Future:
+        """
+        NOTE:
+
+        - The checkpoint process is implemented as a daemon process.
+        The _AsyncCheckpointProcess's lifetime is tied to the lifetime of the
+        main process (e.g. the trainer process).
+
+        - The first call to execute_save() will initialize the checkpoint
+        daemon process. Subsequent async checkpoint requests will not need process
+        initialization; therefore, the first async checkpoint request will take longer to complete.
+
+        - Process initialization can have significant overhead, dominated by the latency of
+        spawning a background process on every rank and initializing a process group inside it.
+ """ + + global _CHECKPOINT_PROCESS + pg_init_info: Optional[_ProcessGroupInitInfo] = None + if _CHECKPOINT_PROCESS is None: + # Find a free port on coordinator rank and broadcast + # to all ranks. + pg_init_info = _ProcessGroupInitInfo(process_group) + + f: Future = self._executor.submit( + self._execute_save_impl, + pg_init_info=pg_init_info, + staged_state_dict=staged_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + ) + f.add_done_callback(lambda f: self._executor.shutdown(wait=False)) + + return f diff --git a/torch/distributed/checkpoint/_async_thread_executor.py b/torch/distributed/checkpoint/_async_thread_executor.py new file mode 100644 index 00000000000..541ad1d8c8e --- /dev/null +++ b/torch/distributed/checkpoint/_async_thread_executor.py @@ -0,0 +1,39 @@ +# pyre-strict +# mypy: allow-untyped-defs +import os +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Optional, Union + +import torch.distributed as dist +from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import StorageWriter + + +class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): + def __init__(self) -> None: + self._executor = ThreadPoolExecutor(max_workers=1) + + def execute_save( + self, + staged_state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + ) -> Future: + from torch.distributed.checkpoint.state_dict_saver import save + + f: Future = self._executor.submit( + save, + staged_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + ) + f.add_done_callback(lambda f: self._executor.shutdown(wait=False)) + + return f diff --git a/torch/distributed/checkpoint/logger.py b/torch/distributed/checkpoint/logger.py index 55ea8a3fa2a..a8961493cbe 100644 --- a/torch/distributed/checkpoint/logger.py +++ b/torch/distributed/checkpoint/logger.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import functools +import logging import time from typing import Any, Callable, TypeVar from typing_extensions import ParamSpec @@ -9,6 +10,9 @@ import torch.distributed.c10d_logger as c10d_logger from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME +logger = logging.getLogger() + + __all__: list[str] = [] global _dcp_logger @@ -101,3 +105,14 @@ def _dcp_method_logger( return wrapper return decorator + + +def _init_logger(rank: int): + logger.setLevel(logging.INFO) + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + formatter = logging.Formatter( + f"[{rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ch.setFormatter(formatter) + logger.addHandler(ch) diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index e36357f9a65..d16c10783c9 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -3,13 +3,23 @@ import inspect import os import warnings -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import Future +from enum import Enum from typing import cast, Optional, Union from typing_extensions import 
 
 import torch
 import torch.distributed as dist
 from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
+from torch.distributed.checkpoint._async_executor import (  # noqa: TC001
+    _AsyncCheckpointExecutor,
+)
+from torch.distributed.checkpoint._async_process_executor import (
+    _ProcessBasedAsyncCheckpointExecutor,
+)
+from torch.distributed.checkpoint._async_thread_executor import (
+    _ThreadBasedAsyncCheckpointExecutor,
+)
 from torch.distributed.checkpoint._storage_utils import _storage_setup
 from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
 from torch.distributed.checkpoint.logger import _dcp_method_logger
@@ -23,7 +33,14 @@ from torch.distributed.distributed_c10d import _get_default_group
 
 from .utils import _api_bc_check, _DistWrapper, _profile
 
-__all__ = ["save_state_dict", "save", "async_save"]
+__all__ = ["save_state_dict", "save", "async_save", "AsyncCheckpointerType"]
+
+
+class AsyncCheckpointerType(Enum):
+    """Enum for async checkpointer type."""
+
+    THREAD = "thread"
+    PROCESS = "process"
 
 
 @deprecated(
@@ -173,6 +190,7 @@ def async_save(
     storage_writer: Optional[StorageWriter] = None,
     planner: Optional[SavePlanner] = None,
     process_group: Optional[dist.ProcessGroup] = None,
+    async_checkpointer_type: AsyncCheckpointerType = AsyncCheckpointerType.THREAD,
 ) -> Future:
     """Asynchronous version of ``save``. This code first de-stages the state_dict on to the
     staging storage (defaults to CPU memory), and then calls the `save` in a separate thread.
@@ -242,16 +260,19 @@
         staged_state_dict = _create_cpu_state_dict(state_dict)
         _copy_state_dict(state_dict, staged_state_dict, type_check=False)
 
-    executor = ThreadPoolExecutor(max_workers=1)
-    f: Future = executor.submit(
-        save,
+    executor: _AsyncCheckpointExecutor = (
+        _ProcessBasedAsyncCheckpointExecutor()
+        if async_checkpointer_type == AsyncCheckpointerType.PROCESS
+        else _ThreadBasedAsyncCheckpointExecutor()
+    )
+
+    f: Future = executor.execute_save(
         staged_state_dict,
         checkpoint_id=checkpoint_id,
         storage_writer=storage_writer,
        planner=planner,
         process_group=process_group,
     )
-    f.add_done_callback(lambda f: executor.shutdown(wait=False))
 
     if (
         isinstance(storage_writer, AsyncStager)