[DCP] Introduce process based async checkpointing (#147039)
Summary:

### Context

In the [async save API](https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py#L239-L248), the background checkpoint upload thread interferes with the trainer thread: it spends a considerable amount of time on CPU-bound work (pickling/unpickling several metadata objects, a.k.a. SavePlans) on rank0 during the collective operation. This asymmetric computation contends heavily for the GIL with the trainer thread, causing GPU utilization to suffer significantly for the end-to-end checkpoint duration.

### Solution

Introduce async save via a checkpoint daemon process. This daemon process is created once (during the first save attempt) and serves async checkpoint requests for the remainder of the training lifetime.

Test Plan: Added E2E UTs for process-based async save.

Differential Revision: D69272583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147039
Approved by: https://github.com/saumishr
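As an illustration of the new API surface, here is a minimal sketch of how a training step might opt into the process-based checkpointer. It assumes torch.distributed is already initialized (e.g. via torchrun); the state_dict contents and checkpoint path are placeholders, not part of this PR:

import torch
from torch.distributed.checkpoint.state_dict_saver import (
    AsyncCheckpointerType,
    async_save,
)

# Placeholder state_dict; real callers would pass model/optimizer state.
state_dict = {"weight": torch.randn(4, 4)}

# The first call spawns the checkpoint daemon process (a one-time cost);
# later calls reuse it, so CPU-bound pickling of SavePlans no longer
# contends for the trainer's GIL.
future = async_save(
    state_dict,
    checkpoint_id="/tmp/ckpt/step_100",  # hypothetical path
    async_checkpointer_type=AsyncCheckpointerType.PROCESS,
)

# ... continue training while the save runs in the background ...
metadata = future.result()  # block only when the result is needed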
This commit is contained in:
parent 16d07988fc
commit fdee60769a
@@ -27,6 +27,9 @@ Additional resources:
 
 .. currentmodule:: torch.distributed.checkpoint.state_dict_saver
 
+.. autoclass:: torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType
+  :members:
+
 .. autofunction:: save
 .. autofunction:: async_save
 .. autofunction:: save_state_dict
@@ -23,6 +23,7 @@ from torch.distributed.checkpoint.state_dict import (
     set_state_dict,
 )
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
+from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.checkpoint.utils import CheckpointException
 from torch.distributed.distributed_c10d import ReduceOp
@@ -214,17 +215,31 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
     @with_comms
     @skip_if_lt_x_gpu(4)
     @with_temp_dir
-    @parametrize("cache_staged_state_dict", [False, True])
-    def test_e2e_async_cached(self, cache_staged_state_dict):
+    @parametrize(
+        "cache_staged_state_dict, async_checkpointer_type",
+        [
+            (False, AsyncCheckpointerType.THREAD),
+            (True, AsyncCheckpointerType.THREAD),
+            (False, AsyncCheckpointerType.PROCESS),
+            (True, AsyncCheckpointerType.PROCESS),
+        ],
+    )
+    def test_e2e_async_cached(self, cache_staged_state_dict, async_checkpointer_type):
         self._run_e2e_test(
             compile=False,
             model_type=ModelType.FSDP,
             async_op=True,
             cache_staged_state_dict=cache_staged_state_dict,
+            async_checkpointer_type=async_checkpointer_type,
         )
 
     def _run_e2e_test(
-        self, compile, model_type, async_op=False, cache_staged_state_dict=False
+        self,
+        compile,
+        model_type,
+        async_op=False,
+        cache_staged_state_dict=False,
+        async_checkpointer_type=None,
     ):
         model, optim = self._create_model(compile, ModelType.NONE)
         _train(model, optim, train_steps=2)
@@ -244,7 +259,13 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
             writer = DCP.FileSystemWriter(
                 self.temp_dir, cache_staged_state_dict=cache_staged_state_dict
             )
-            f = saver.async_save(sd, storage_writer=writer)
+            f = saver.async_save(
+                sd,
+                storage_writer=writer,
+                async_checkpointer_type=async_checkpointer_type
+                if async_checkpointer_type
+                else AsyncCheckpointerType.THREAD,
+            )
             t = time.monotonic()
             while not f.done():
                 time.sleep(1)
torch/distributed/checkpoint/_async_executor.py (new file, 32 lines)

@@ -0,0 +1,32 @@
# pyre-strict
# mypy: allow-untyped-defs
import abc
import os
from concurrent.futures import Future
from typing import Optional, Union

import torch.distributed as dist
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter


class _AsyncCheckpointExecutor(abc.ABC):
    @abc.abstractmethod
    def execute_save(
        self,
        staged_state_dict: STATE_DICT_TYPE,
        *,
        checkpoint_id: Union[str, os.PathLike, None] = None,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
        process_group: Optional[dist.ProcessGroup] = None,
    ) -> Future:
        """
        Execute the checkpoint save request asynchronously.

        This method is intended to be used as an abstraction for
        implementing async checkpointing. The actual checkpoint save
        operation is executed in a separate thread or process depending
        on the implementation of this interface.
        """
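To make the contract concrete, here is a hedged sketch of a third implementation that is hypothetical and not part of this PR: an executor that performs the save inline and returns an already-completed Future, satisfying the same interface as the thread- and process-based executors added below.

# Hypothetical executor, not part of this PR: runs the save inline and
# returns an already-resolved Future. Removing all concurrency from the
# save path can make checkpoint issues easier to debug.
import os
from concurrent.futures import Future
from typing import Optional, Union

import torch.distributed as dist
from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter


class _InlineAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
    def execute_save(
        self,
        staged_state_dict: STATE_DICT_TYPE,
        *,
        checkpoint_id: Union[str, os.PathLike, None] = None,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
        process_group: Optional[dist.ProcessGroup] = None,
    ) -> Future:
        from torch.distributed.checkpoint.state_dict_saver import save

        f: Future = Future()
        try:
            f.set_result(
                save(
                    staged_state_dict,
                    checkpoint_id=checkpoint_id,
                    storage_writer=storage_writer,
                    planner=planner,
                    process_group=process_group,
                )
            )
        except BaseException as e:  # surface failures through the Future
            f.set_exception(e)
        return f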
torch/distributed/checkpoint/_async_process_executor.py (new file, 307 lines)

@@ -0,0 +1,307 @@
# pyre-strict
# mypy: allow-untyped-defs
import logging
import os
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Union
from uuid import uuid4

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor
from torch.distributed.checkpoint.logger import _dcp_method_logger, _init_logger
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.checkpoint.utils import _DistWrapper
from torch.distributed.elastic.agent.server.api import _get_fq_hostname
from torch.distributed.elastic.utils.distributed import get_free_port


logger = logging.getLogger()


class _CheckpointSaveProcessControlOpts(Enum):
    INIT_COMPLETE = "init_complete"
    TERMINATE = "terminate"


@dataclass(init=False, unsafe_hash=True)
class _CheckpointRequestIdentifier:
    checkpoint_id: Union[str, os.PathLike, None]
    uuid: str

    def __init__(self, checkpoint_id: Union[str, os.PathLike, None]):
        self.checkpoint_id = checkpoint_id
        self.uuid = str(uuid4())


@dataclass
class _AsyncCheckpointRequest:
    staged_state_dict: STATE_DICT_TYPE
    checkpoint_request_id: _CheckpointRequestIdentifier
    storage_writer: Optional[StorageWriter] = None
    planner: Optional[SavePlanner] = None


@dataclass(init=False)
class _ProcessGroupInitInfo:
    local_rank: int
    global_rank: int
    world_size: int
    tcp_store_master_addr: str
    tcp_store_master_port: int

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
        self.local_rank = dist.get_node_local_rank(fallback_rank=0)
        self.global_rank = dist.get_rank(process_group)
        self.world_size = dist.get_world_size(process_group)

        # Let coordinator rank find a free port on the localhost.
        # Broadcast the (master_addr, free_port) to all ranks; each rank in the
        # checkpoint daemon process will use TCPStore (master_addr, master_port)
        # for collective communication.
        dist_wrapper: _DistWrapper = _DistWrapper(
            group=process_group,
            use_dist=True,
            coordinator_rank=0,
        )

        def get_master_addr_and_port() -> tuple[str, int]:
            master_addr = os.environ.get("MASTER_ADDR")
            if master_addr is None:
                master_addr = _get_fq_hostname()
            return master_addr, get_free_port()

        self.tcp_store_master_addr, self.tcp_store_master_port = dist_wrapper.broadcast(
            step="get_master_addr_and_port",
            map_fun=get_master_addr_and_port,
        )


class _AsyncCheckpointProcess:
    def __init__(
        self,
        pg_init_info: _ProcessGroupInitInfo,
    ):
        self.ctx = mp.get_context("spawn")
        self._mp_queue_send: mp.Queue = self.ctx.Queue()
        self._mp_queue_recv: mp.Queue = self.ctx.Queue()

        self._save_process = self.ctx.Process(
            target=self._checkpointing_subprocess,
            args=(
                pg_init_info,
                self._mp_queue_send,
                self._mp_queue_recv,
            ),
            daemon=True,
        )

        self._save_process.start()
        response = self._wait_for_response()
        assert response == _CheckpointSaveProcessControlOpts.INIT_COMPLETE

    def __del__(self) -> None:
        if self._save_process.is_alive():
            logger.info("Terminating the checkpoint background process...")
            self._mp_queue_send.put(_CheckpointSaveProcessControlOpts.TERMINATE)
            self._save_process.join()

    def save(
        self,
        staged_state_dict: STATE_DICT_TYPE,
        *,
        checkpoint_id: Union[str, os.PathLike, None] = None,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
    ) -> Metadata:
        # Create a unique identifier to locate requests/responses
        # from the checkpoint daemon process.
        checkpoint_request_id = _CheckpointRequestIdentifier(checkpoint_id)
        async_cp_request = _AsyncCheckpointRequest(
            staged_state_dict=staged_state_dict,
            checkpoint_request_id=checkpoint_request_id,
            storage_writer=storage_writer,
            planner=planner,
        )
        self._mp_queue_send.put(async_cp_request)
        result = self._wait_for_response()
        assert isinstance(result, Metadata)
        return result

    def _wait_for_response(self) -> Any:
        if not self._save_process.is_alive():
            logger.info("Checkpoint background process is dead, calling join()...")
            self._save_process.join()
            raise RuntimeError("Checkpoint background process is dead.")
        response = self._mp_queue_recv.get()
        if isinstance(response, BaseException):
            raise response
        return response

    @staticmethod
    def _execute_save(
        state_dict: STATE_DICT_TYPE,
        *,
        checkpoint_request_id: _CheckpointRequestIdentifier,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
    ) -> Metadata:
        from torch.distributed.checkpoint.state_dict_saver import save

        metadata = save(
            state_dict,
            checkpoint_id=checkpoint_request_id.checkpoint_id,
            storage_writer=storage_writer,
            planner=planner,
        )
        return metadata

    @staticmethod
    def _checkpointing_subprocess(
        pg_init_info: _ProcessGroupInitInfo,
        recv: mp.Queue,
        send: mp.Queue,
    ) -> None:
        try:
            _init_logger(pg_init_info.global_rank)

            # Setup environment variables for process group initialization.
            os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
            os.environ["MASTER_ADDR"] = pg_init_info.tcp_store_master_addr
            os.environ["MASTER_PORT"] = str(pg_init_info.tcp_store_master_port)
            os.environ["LOCAL_RANK"] = str(pg_init_info.local_rank)
            os.environ["RANK"] = str(pg_init_info.global_rank)
            os.environ["WORLD_SIZE"] = str(pg_init_info.world_size)

            logger.info(
                "Initializing dist.ProcessGroup in checkpoint background process"
            )
            # NOTE: GLOO backend is enforced here.
            dist.init_process_group(backend=dist.Backend.GLOO)
            dist.barrier()

            logger.info("Checkpoint background process is running...")
            send.put(_CheckpointSaveProcessControlOpts.INIT_COMPLETE)

            # Serving loop.
            while True:
                logger.info("Waiting for checkpoint save request...")
                obj = recv.get()
                if (
                    isinstance(obj, _CheckpointSaveProcessControlOpts)
                    and obj == _CheckpointSaveProcessControlOpts.TERMINATE
                ):
                    logger.info("Terminating the checkpoint background process.")
                    return
                assert isinstance(obj, _AsyncCheckpointRequest)
                logger.info(
                    f"Received async checkpoint request with id={obj.checkpoint_request_id.checkpoint_id}"  # noqa: G004
                )

                response = _AsyncCheckpointProcess._execute_save(
                    obj.staged_state_dict,
                    checkpoint_request_id=obj.checkpoint_request_id,
                    storage_writer=obj.storage_writer,
                    planner=obj.planner,
                )
                send.put(response)
                logger.info(
                    f"Submitted checkpoint save request for checkpoint_id={obj.checkpoint_request_id}"  # noqa: G004
                )
        except BaseException as e:
            logger.error(
                f"Checkpoint background process encountered an exception: {e}"  # noqa: G004
            )
            send.put(e)
            raise
        finally:
            logger.info("Checkpoint background process is shutting down...")
            dist.destroy_process_group()


_CHECKPOINT_PROCESS: Optional[_AsyncCheckpointProcess] = None


class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)

    @staticmethod
    def _execute_save_impl(
        *,
        pg_init_info: Optional[_ProcessGroupInitInfo],
        staged_state_dict: STATE_DICT_TYPE,
        checkpoint_id: Union[str, os.PathLike, None] = None,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
        process_group: Optional[dist.ProcessGroup] = None,
    ) -> Metadata:
        global _CHECKPOINT_PROCESS
        if _CHECKPOINT_PROCESS is None:
            assert pg_init_info is not None
            ckpt_kwargs = {}
            if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
                ckpt_kwargs["checkpoint_id"] = ckpt_id
                ckpt_kwargs["process_group"] = process_group

            @_dcp_method_logger(**ckpt_kwargs)
            def create_checkpoint_daemon_process() -> None:
                global _CHECKPOINT_PROCESS
                _CHECKPOINT_PROCESS = _AsyncCheckpointProcess(pg_init_info=pg_init_info)

            create_checkpoint_daemon_process()

        assert _CHECKPOINT_PROCESS is not None
        return _CHECKPOINT_PROCESS.save(
            staged_state_dict=staged_state_dict,
            checkpoint_id=checkpoint_id,
            storage_writer=storage_writer,
            planner=planner,
        )

    def execute_save(
        self,
        staged_state_dict: STATE_DICT_TYPE,
        *,
        checkpoint_id: Union[str, os.PathLike, None] = None,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
        process_group: Optional[dist.ProcessGroup] = None,
    ) -> Future:
        """
        NOTE:

        - The checkpoint process is implemented as a daemon process.
          The _AsyncCheckpointProcess's lifetime is tied to the lifetime of the
          main process (e.g. the trainer process).

        - The first call to execute_save() will initialize the checkpoint
          daemon process. Subsequent async checkpoint requests will not need process
          initialization; therefore, the first async checkpoint request will take
          longer to complete.

        - Process initialization can have significant overhead, dominated by the
          latency for all ranks to spawn a background process plus process group
          initialization in the background process.
        """

        global _CHECKPOINT_PROCESS
        pg_init_info: Optional[_ProcessGroupInitInfo] = None
        if _CHECKPOINT_PROCESS is None:
            # Find a free port on coordinator rank and broadcast
            # to all ranks.
            pg_init_info = _ProcessGroupInitInfo(process_group)

        f: Future = self._executor.submit(
            self._execute_save_impl,
            pg_init_info=pg_init_info,
            staged_state_dict=staged_state_dict,
            checkpoint_id=checkpoint_id,
            storage_writer=storage_writer,
            planner=planner,
        )
        f.add_done_callback(lambda f: self._executor.shutdown(wait=False))

        return f
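To observe the first-request overhead described in the NOTE above, one can time a cold save against a warm one. A minimal sketch, assuming a process group is already initialized and state_dict is already defined; the timed_save helper and the checkpoint paths are illustrative only:

import time

from torch.distributed.checkpoint.state_dict_saver import (
    AsyncCheckpointerType,
    async_save,
)


def timed_save(state_dict, path):
    # Illustrative helper: submit an async save and block until it finishes.
    t0 = time.monotonic()
    async_save(
        state_dict,
        checkpoint_id=path,
        async_checkpointer_type=AsyncCheckpointerType.PROCESS,
    ).result()
    return time.monotonic() - t0


# The first call pays for spawning the daemon on every rank plus GLOO
# process-group init inside it; later calls reuse _CHECKPOINT_PROCESS.
cold = timed_save(state_dict, "/tmp/ckpt/step_0")  # hypothetical paths
warm = timed_save(state_dict, "/tmp/ckpt/step_1")
print(f"cold={cold:.2f}s warm={warm:.2f}s")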
torch/distributed/checkpoint/_async_thread_executor.py (new file, 39 lines)

@@ -0,0 +1,39 @@
# pyre-strict
# mypy: allow-untyped-defs
import os
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional, Union

import torch.distributed as dist
from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter


class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)

    def execute_save(
        self,
        staged_state_dict: STATE_DICT_TYPE,
        *,
        checkpoint_id: Union[str, os.PathLike, None] = None,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
        process_group: Optional[dist.ProcessGroup] = None,
    ) -> Future:
        from torch.distributed.checkpoint.state_dict_saver import save

        f: Future = self._executor.submit(
            save,
            staged_state_dict,
            checkpoint_id=checkpoint_id,
            storage_writer=storage_writer,
            planner=planner,
            process_group=process_group,
        )
        f.add_done_callback(lambda f: self._executor.shutdown(wait=False))

        return f
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import functools
+import logging
 import time
 from typing import Any, Callable, TypeVar
 from typing_extensions import ParamSpec
@@ -9,6 +10,9 @@ import torch.distributed.c10d_logger as c10d_logger
 from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
 
 
+logger = logging.getLogger()
+
+
 __all__: list[str] = []
 
 global _dcp_logger
@@ -101,3 +105,14 @@ def _dcp_method_logger(
         return wrapper
 
     return decorator
+
+
+def _init_logger(rank: int):
+    logger.setLevel(logging.INFO)
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.INFO)
+    formatter = logging.Formatter(
+        f"[{rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
@@ -3,13 +3,23 @@
 import inspect
 import os
 import warnings
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import Future
+from enum import Enum
 from typing import cast, Optional, Union
 from typing_extensions import deprecated
 
 import torch
 import torch.distributed as dist
 from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
+from torch.distributed.checkpoint._async_executor import (  # noqa: TC001
+    _AsyncCheckpointExecutor,
+)
+from torch.distributed.checkpoint._async_process_executor import (
+    _ProcessBasedAsyncCheckpointExecutor,
+)
+from torch.distributed.checkpoint._async_thread_executor import (
+    _ThreadBasedAsyncCheckpointExecutor,
+)
 from torch.distributed.checkpoint._storage_utils import _storage_setup
 from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
 from torch.distributed.checkpoint.logger import _dcp_method_logger
@@ -23,7 +33,14 @@ from torch.distributed.distributed_c10d import _get_default_group
 from .utils import _api_bc_check, _DistWrapper, _profile
 
 
-__all__ = ["save_state_dict", "save", "async_save"]
+__all__ = ["save_state_dict", "save", "async_save", "AsyncCheckpointerType"]
+
+
+class AsyncCheckpointerType(Enum):
+    """Enum for async checkpointer type."""
+
+    THREAD = "thread"
+    PROCESS = "process"
 
 
 @deprecated(
@@ -173,6 +190,7 @@ def async_save(
     storage_writer: Optional[StorageWriter] = None,
     planner: Optional[SavePlanner] = None,
     process_group: Optional[dist.ProcessGroup] = None,
+    async_checkpointer_type: AsyncCheckpointerType = AsyncCheckpointerType.THREAD,
 ) -> Future:
     """Asynchronous version of ``save``. This code first de-stages the state_dict on to the
     staging storage (defaults to CPU memory), and then calls the `save` in a separate thread.
@@ -242,16 +260,19 @@ def async_save(
     staged_state_dict = _create_cpu_state_dict(state_dict)
     _copy_state_dict(state_dict, staged_state_dict, type_check=False)
 
-    executor = ThreadPoolExecutor(max_workers=1)
-    f: Future = executor.submit(
-        save,
+    executor: _AsyncCheckpointExecutor = (
+        _ProcessBasedAsyncCheckpointExecutor()
+        if async_checkpointer_type == AsyncCheckpointerType.PROCESS
+        else _ThreadBasedAsyncCheckpointExecutor()
+    )
+
+    f: Future = executor.execute_save(
         staged_state_dict,
         checkpoint_id=checkpoint_id,
         storage_writer=storage_writer,
         planner=planner,
         process_group=process_group,
     )
-    f.add_done_callback(lambda f: executor.shutdown(wait=False))
 
     if (
         isinstance(storage_writer, AsyncStager)