Add async checkpointing impl to experimental checkpointer and add a builder API (#156927)

1. Adds an AsyncCheckpointer with out-of-process checkpointing, plus a state_dict stager with shared-memory, pinned-memory, and zero-overhead support.

2. Adds two convenience factory functions, make_sync_checkpointer and make_async_checkpointer, for creating sync/async checkpointers.
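
A minimal usage sketch of the new builder API, adapted from the tests added in this PR (paths are illustrative and the state_dict is any serializable mapping):

    import torch
    from torch.distributed.checkpoint._experimental.barriers import BarrierConfig
    from torch.distributed.checkpoint._experimental.builder import (
        make_async_checkpointer,
        make_sync_checkpointer,
    )
    from torch.distributed.checkpoint._experimental.config import CheckpointerConfig
    from torch.distributed.checkpoint._experimental.staging import CheckpointStagerConfig

    state_dict = {"model": torch.nn.Linear(10, 5).state_dict(), "epoch": 5}

    # Synchronous: save() blocks and returns None.
    sync_config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
    sync_ckpt = make_sync_checkpointer(config=sync_config)
    sync_ckpt.save(state_dict, "/tmp/ckpt_sync")  # illustrative path

    # Asynchronous: save() returns (stage_future, write_future).
    async_config = CheckpointerConfig()
    async_config.staging_config = CheckpointStagerConfig(
        use_pinned_memory=torch.cuda.is_available(),
        use_cuda_non_blocking_copy=torch.cuda.is_available(),
    )
    async_ckpt = make_async_checkpointer(config=async_config)
    try:
        stage_future, write_future = async_ckpt.save(state_dict, "/tmp/ckpt_async")
        write_future.result()  # wait for the background write to finish
        loaded = async_ckpt.load("/tmp/ckpt_async")
    finally:
        async_ckpt.close()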

Differential Revision: [D77336833](https://our.internmc.facebook.com/intern/diff/D77336833/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156927
Approved by: https://github.com/pradeepfn
Teja 2025-07-03 10:35:15 -07:00 committed by PyTorch MergeBot
parent 7081b8233a
commit dd3e7170c2
12 changed files with 1890 additions and 0 deletions

View File

@@ -0,0 +1,165 @@
# Owner(s): ["oncall: distributed checkpointing"]
import os
import shutil
import tempfile
import torch
from torch.distributed.checkpoint._experimental.barriers import BarrierConfig
from torch.distributed.checkpoint._experimental.builder import (
make_async_checkpointer,
make_sync_checkpointer,
)
from torch.distributed.checkpoint._experimental.checkpointer import (
AsyncCheckpointer,
SyncCheckpointer,
)
from torch.distributed.checkpoint._experimental.config import CheckpointerConfig
from torch.distributed.checkpoint._experimental.staging import CheckpointStagerConfig
from torch.distributed.checkpoint._experimental.types import RankInfo
from torch.testing._internal.common_utils import run_tests, TestCase
class TestMakeCheckpointer(TestCase):
def setUp(self) -> None:
# Create a temporary directory for checkpoints
self.temp_dir = tempfile.mkdtemp()
# Create real objects for testing
self.rank_info = RankInfo(
global_world_size=1,
global_rank=0,
)
# Create a test state dictionary
self.state_dict = {
"model": torch.nn.Linear(10, 5).state_dict(),
"optimizer": {"param_groups": [{"lr": 0.01}]},
"epoch": 5,
"step": 1000,
}
def tearDown(self) -> None:
# Clean up the temporary directory
shutil.rmtree(self.temp_dir)
def test_make_sync_checkpointer(self) -> None:
"""Test creating a synchronous checkpointer using make_sync_checkpointer."""
# Create sync checkpointer using factory function with no barrier
config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
checkpointer = make_sync_checkpointer(config=config, rank_info=self.rank_info)
# Verify it's a SyncCheckpointer instance
self.assertIsInstance(checkpointer, SyncCheckpointer)
# Test that it works for sync operations
checkpoint_path = os.path.join(self.temp_dir, "checkpoint_factory_sync")
result = checkpointer.save(self.state_dict, checkpoint_path)
self.assertIsNone(result) # Sync mode returns None
# Verify checkpoint was created
checkpoint_file = os.path.join(
checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
)
self.assertTrue(os.path.exists(checkpoint_file))
# Test loading
loaded_state_dict = checkpointer.load(checkpoint_path)
self.assertEqual(loaded_state_dict["epoch"], 5)
def test_make_sync_checkpointer_with_config_first(self) -> None:
"""Test creating a synchronous checkpointer with config as first parameter."""
# Create sync checkpointer with config as first parameter
config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
checkpointer = make_sync_checkpointer(config=config, rank_info=self.rank_info)
# Verify it's a SyncCheckpointer instance
self.assertIsInstance(checkpointer, SyncCheckpointer)
# Test that it works for sync operations
checkpoint_path = os.path.join(
self.temp_dir, "checkpoint_factory_sync_config_first"
)
result = checkpointer.save(self.state_dict, checkpoint_path)
self.assertIsNone(result) # Sync mode returns None
# Verify checkpoint was created
checkpoint_file = os.path.join(
checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
)
self.assertTrue(os.path.exists(checkpoint_file))
def test_make_sync_checkpointer_with_custom_config(self) -> None:
"""Test creating a synchronous checkpointer with a custom config."""
# Create a custom config with no barrier
config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
# Create sync checkpointer with the custom config
checkpointer = make_sync_checkpointer(rank_info=self.rank_info, config=config)
# Verify it's a SyncCheckpointer instance
self.assertIsInstance(checkpointer, SyncCheckpointer)
# Test that it works for sync operations
checkpoint_path = os.path.join(
self.temp_dir, "checkpoint_factory_sync_custom_config"
)
result = checkpointer.save(self.state_dict, checkpoint_path)
self.assertIsNone(result) # Sync mode returns None
# Verify checkpoint was created
checkpoint_file = os.path.join(
checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
)
self.assertTrue(os.path.exists(checkpoint_file))
# Test loading
loaded_state_dict = checkpointer.load(checkpoint_path)
self.assertEqual(loaded_state_dict["epoch"], 5)
def test_make_async_checkpointer(self) -> None:
"""Test creating an asynchronous checkpointer using make_async_checkpointer."""
# Create async checkpointer using factory function with default parameters
config: CheckpointerConfig = CheckpointerConfig()
config.staging_config = CheckpointStagerConfig(
use_cuda_non_blocking_copy=torch.cuda.is_available(),
use_pinned_memory=torch.cuda.is_available(),
)
checkpointer = make_async_checkpointer(config=config, rank_info=self.rank_info)
try:
# Verify it's an AsyncCheckpointer instance
self.assertIsInstance(checkpointer, AsyncCheckpointer)
# Test that it works for async operations
checkpoint_path = os.path.join(self.temp_dir, "checkpoint_factory_async")
stage_future, write_future = checkpointer.save(
self.state_dict, checkpoint_path
)
# Verify futures are returned
self.assertIsNotNone(stage_future)
self.assertIsNotNone(write_future)
# Wait for completion
stage_future.result()
write_future.result()
# Verify checkpoint was created
checkpoint_file = os.path.join(
checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
)
self.assertTrue(os.path.exists(checkpoint_file))
# Test loading
loaded_state_dict = checkpointer.load(checkpoint_path)
self.assertEqual(loaded_state_dict["epoch"], 5)
finally:
# Clean up
checkpointer.close()
if __name__ == "__main__":
run_tests()

View File

@@ -0,0 +1,465 @@
# Owner(s): ["oncall: distributed checkpointing"]
import os
import tempfile
import time
from concurrent.futures import Future
from typing import Any
import torch
from torch.distributed.checkpoint._experimental.checkpoint_process import (
CheckpointProcess,
CheckpointProcessConfig,
RequestType,
WorkerRequest,
WorkerResponse,
)
from torch.distributed.checkpoint._experimental.checkpoint_writer import (
CheckpointWriter,
CheckpointWriterConfig,
)
from torch.distributed.checkpoint._experimental.types import RankInfo
from torch.testing._internal.common_utils import run_tests, TestCase
def subprocess_init_fn(name: str, parent_pid: int) -> None:
"""Initialize the subprocess with some basic checks.
This is similar to the subprocess_init_routine in checkpointing_test.py.
"""
assert name == "test-checkpointer", f"Unexpected subprocess name: {name}"
assert os.getpid() != parent_pid, "This was supposed to run in a different process"
assert os.getppid() == parent_pid, (
"This was supposed to run as a child to main process"
)
def failing_subprocess_init_fn(name: str, parent_pid: int) -> None:
"""Initialize function that raises an exception."""
# Acknowledge parameters to avoid unused variable warnings
_ = name
_ = parent_pid
raise RuntimeError("Subprocess initialization failed")
def timedout_subprocess_init_fn(**kwargs: Any) -> None:
# Acknowledge parameters to avoid unused variable warnings
_ = kwargs
time.sleep(3) # Simulate a long initialization
def ckpt_writer_init_fn(**kwargs: Any) -> CheckpointWriter:
"""Initialize a CheckpointWriter in the subprocess.
This function is called in the subprocess to create a CheckpointWriter instance.
It's important that this function is defined at the module level so it can be pickled.
"""
return CheckpointWriter(
config=kwargs.get("config"),
rank_info=kwargs.get("rank_info"),
)
def failing_ckpt_writer_init_fn(**kwargs: Any) -> CheckpointWriter:
"""Initialize function that raises an exception."""
# Acknowledge parameters to avoid unused variable warnings
_ = kwargs
raise RuntimeError("CheckpointWriter initialization failed")
def shared_tensor_verifier_init_fn(**kwargs: Any) -> CheckpointWriter:
"""Initialize a CheckpointWriter that verifies shared memory tensors."""
class SharedTensorVerifier(CheckpointWriter):
def __init__(self, config=None, rank_info=None, **init_kwargs):
# Acknowledge unused kwargs to avoid linting warnings
_ = init_kwargs
super().__init__(
config=config or CheckpointWriterConfig(),
rank_info=rank_info,
barrier=None,
commit_hook=None,
)
def write(self, state_dict, path, **__):
# Acknowledge parameters to avoid unused variable warnings
_ = path
# Verify shared memory tensor behavior directly with assertions
if "shared_tensor" in state_dict:
shared_tensor = state_dict["shared_tensor"]
# Critical assertion: shared tensor should remain in shared memory in subprocess
assert shared_tensor.is_shared(), (
"Shared tensor should be in shared memory in subprocess"
)
shared_tensor[0] = 42.0
if "regular_tensor" in state_dict:
# Note: ForkingPickler moves regular tensors to shared memory during IPC - this is acceptable
assert state_dict["regular_tensor"].is_shared(), (
"Regular tensor should also be in shared memory in subprocess"
)
return None
verifier = SharedTensorVerifier(
config=kwargs.get("config"),
rank_info=kwargs.get("rank_info"),
)
return verifier
class TestRequestTypes(TestCase):
"""Test the request/response data structures."""
def test_request_type_enum(self) -> None:
"""Test RequestType enum values."""
self.assertEqual(RequestType.PING.value, "ping")
self.assertEqual(RequestType.WRITE_CHECKPOINT.value, "write_checkpoint")
self.assertEqual(RequestType.TERMINATE_PROCESS.value, "exit")
def test_worker_request(self) -> None:
"""Test WorkerRequest dataclass."""
request = WorkerRequest(request_type=RequestType.PING, payload={"test": "data"})
self.assertEqual(request.request_type, RequestType.PING)
self.assertEqual(request.payload["test"], "data")
def test_worker_response(self) -> None:
"""Test WorkerResponse dataclass."""
response = WorkerResponse(
request_type=RequestType.PING,
success=True,
error_msg=None,
payload={"result": "success"},
)
self.assertEqual(response.request_type, RequestType.PING)
self.assertTrue(response.success)
self.assertIsNone(response.error_msg)
self.assertEqual(response.payload["result"], "success")
class TestCheckpointProcessConfig(TestCase):
"""Test CheckpointProcessConfig configuration."""
def test_default_options(self) -> None:
"""Test default CheckpointProcessConfig."""
options = CheckpointProcessConfig()
# Test default values
self.assertEqual(options.subprocess_init_timeout_secs, 30)
self.assertEqual(options.subprocess_shutdown_timeout_secs, 60)
def test_custom_options(self) -> None:
"""Test custom CheckpointProcessConfig."""
options = CheckpointProcessConfig(
subprocess_init_timeout_secs=10, subprocess_shutdown_timeout_secs=30
)
self.assertEqual(options.subprocess_init_timeout_secs, 10)
self.assertEqual(options.subprocess_shutdown_timeout_secs, 30)
class TestCheckpointProcess(TestCase):
def setUp(self) -> None:
"""Set up common test fixtures."""
self.rank_info = RankInfo(
global_world_size=1,
global_rank=0,
)
self.writer_config = CheckpointWriterConfig()
self.test_state_dict = {
"model": torch.nn.Linear(10, 5).state_dict(),
"optimizer": {"param_groups": [{"lr": 0.01}]},
"epoch": 5,
"step": 1000,
}
def _create_checkpoint_process(
self,
subprocess_init_fn_override=None,
subprocess_init_args_override=None,
writer_init_fn_override=None,
subprocess_init_timeout_secs=30,
):
"""Helper to create CheckpointProcess."""
config = CheckpointProcessConfig(
subprocess_init_timeout_secs=subprocess_init_timeout_secs,
)
return CheckpointProcess(
rank_info=self.rank_info,
config=config,
subprocess_init_fn=subprocess_init_fn_override or subprocess_init_fn,
subprocess_init_args=subprocess_init_args_override
or (
"test-checkpointer",
os.getpid(),
),
checkpoint_writer_init_fn=writer_init_fn_override or ckpt_writer_init_fn,
checkpoint_writer_init_args={
"config": self.writer_config,
"rank_info": self.rank_info,
},
)
def test_checkpoint_process_initialization(self) -> None:
"""Test that CheckpointProcess initializes and closes correctly."""
checkpoint_process = self._create_checkpoint_process()
# Wait for the process creation future to complete
checkpoint_process.process_creation_future.result()
# Verify process is alive
self.assertTrue(checkpoint_process.process.processes[0].is_alive())
checkpoint_process.close()
# Verify process is terminated
self.assertFalse(checkpoint_process.process.processes[0].is_alive())
def test_checkpoint_write_sync_state_dict(self) -> None:
"""Test writing a checkpoint with synchronous state dict."""
checkpoint_process = self._create_checkpoint_process()
# Wait for initialization
checkpoint_process.process_creation_future.result()
# Create a temporary directory for the checkpoint
with tempfile.TemporaryDirectory() as temp_dir:
checkpoint_path = os.path.join(temp_dir, "test_checkpoint")
# Write checkpoint
future = checkpoint_process.write(self.test_state_dict, checkpoint_path)
# Verify future is returned
self.assertIsInstance(future, Future)
# Wait for completion
future.result()
# Verify checkpoint file was created
expected_file = os.path.join(
checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
)
self.assertTrue(os.path.exists(expected_file))
# Verify checkpoint content
loaded_state_dict = torch.load(expected_file)
self.assertIn("model", loaded_state_dict)
self.assertIn("optimizer", loaded_state_dict)
self.assertEqual(loaded_state_dict["epoch"], 5)
self.assertEqual(loaded_state_dict["step"], 1000)
checkpoint_process.close()
def test_checkpoint_write_future_state_dict(self) -> None:
"""Test writing a checkpoint with Future state dict."""
checkpoint_process = self._create_checkpoint_process()
# Wait for initialization
checkpoint_process.process_creation_future.result()
# Create a Future that resolves to the state dict
from concurrent.futures import ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=1)
def get_state_dict():
time.sleep(0.1) # Simulate some processing time
return self.test_state_dict
future_state_dict = executor.submit(get_state_dict)
# Create a temporary directory for the checkpoint
with tempfile.TemporaryDirectory() as temp_dir:
checkpoint_path = os.path.join(temp_dir, "test_checkpoint")
# Write checkpoint with Future state dict
write_future = checkpoint_process.write(future_state_dict, checkpoint_path)
# Wait for completion
write_future.result()
# Verify checkpoint file was created
expected_file = os.path.join(
checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
)
self.assertTrue(os.path.exists(expected_file))
executor.shutdown(wait=True)
checkpoint_process.close()
def test_checkpoint_write_with_kwargs(self) -> None:
"""Test checkpoint writing with additional kwargs."""
checkpoint_process = self._create_checkpoint_process()
# Wait for initialization
checkpoint_process.process_creation_future.result()
with tempfile.TemporaryDirectory() as temp_dir:
checkpoint_path = os.path.join(temp_dir, "test_checkpoint")
# Write checkpoint with kwargs
future = checkpoint_process.write(
self.test_state_dict,
checkpoint_path,
custom_arg="test_value",
another_arg=42,
)
# Wait for completion
future.result()
# Verify checkpoint was created
expected_file = os.path.join(
checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
)
self.assertTrue(os.path.exists(expected_file))
checkpoint_process.close()
def test_subprocess_initialization_timeout(self) -> None:
"""Test subprocess initialization timeout."""
# Create checkpoint process with a very short timeout by mocking the initialization
checkpoint_process = self._create_checkpoint_process(
subprocess_init_fn_override=timedout_subprocess_init_fn,
subprocess_init_timeout_secs=1,
)
# This should timeout
with self.assertRaises(TimeoutError) as cm:
checkpoint_process.process_creation_future.result()
self.assertIn("Timed out", str(cm.exception))
def test_subprocess_initialization_failure(self) -> None:
"""Test subprocess initialization failure."""
checkpoint_process = self._create_checkpoint_process(
subprocess_init_fn_override=failing_subprocess_init_fn
)
# The subprocess should fail to initialize
# We expect this to raise an exception when we try to use it
with self.assertRaises(RuntimeError):
checkpoint_process.process_creation_future.result()
def test_graceful_termination(self) -> None:
"""Test graceful termination of subprocess."""
checkpoint_process = self._create_checkpoint_process()
checkpoint_process.process_creation_future.result()
self.assertTrue(checkpoint_process.process.processes[0].is_alive())
checkpoint_process.close()
self.assertFalse(checkpoint_process.process.processes[0].is_alive())
def test_forced_termination(self) -> None:
"""Test forced termination when graceful termination fails."""
checkpoint_process = self._create_checkpoint_process()
# Wait for initialization
checkpoint_process.process_creation_future.result()
# Mock the join method to simulate timeout
def mock_join(timeout=None):
# Acknowledge timeout parameter to avoid unused variable warning
_ = timeout
return False # Simulate timeout
checkpoint_process.process.join = mock_join
# This should trigger forced termination
checkpoint_process.close()
# Process should still be terminated (killed)
# Note: This test might be flaky depending on timing
def test_communication_error_handling(self):
"""Test handling of communication errors."""
checkpoint_process = self._create_checkpoint_process()
# Wait for initialization
checkpoint_process.process_creation_future.result()
# Close the pipe to simulate communication failure
checkpoint_process._parent_end.close()
# Attempting to write should raise an error
with self.assertRaises(RuntimeError) as cm:
future = checkpoint_process.write(self.test_state_dict, "/tmp/test")
future.result()
self.assertIn("Child process terminated unexpectedly", str(cm.exception))
def test_shared_memory_tensor_ipc(self):
"""Test that shared memory tensors are backed by the same memory across processes."""
checkpoint_process = self._create_checkpoint_process(
writer_init_fn_override=shared_tensor_verifier_init_fn,
)
checkpoint_process.process_creation_future.result()
# Create tensors and put them in shared memory
shared_tensor = torch.randn(100, 100)
shared_tensor.share_memory_()
shared_tensor_data_ptr = shared_tensor.data_ptr()
regular_tensor = torch.randn(50, 50)
# Don't put regular tensor in shared memory for comparison
# Verify initial shared memory status
self.assertTrue(
shared_tensor.is_shared(), "Shared tensor should be in shared memory"
)
self.assertFalse(
regular_tensor.is_shared(), "Regular tensor should not be in shared memory"
)
# Create state dict with mixed tensor types
test_state_dict = {
"shared_tensor": shared_tensor,
"regular_tensor": regular_tensor,
}
# Write to subprocess - the SharedTensorVerifier will:
# 1. Verify the tensor is still in shared memory
# 2. Check the marker value (42.0) to confirm same memory
# 3. Modify specific positions to prove same memory access
future = checkpoint_process.write(test_state_dict, "")
try:
result = (
future.result()
) # This will raise an exception if the subprocess assertions fail
self.assertIsNone(result) # SharedTensorVerifier returns None on success
except Exception as e:
self.fail(f"Subprocess assertions failed: {e}")
# assert shared tensor is still in same shared memory
self.assertEqual(
shared_tensor_data_ptr,
shared_tensor.data_ptr(),
"Shared tensor should still be in same shared memory",
)
self.assertTrue(
shared_tensor.is_shared(), "Shared tensor should still be in shared memory"
)
# CRITICAL TEST: Verify that modifications made by subprocess are visible in main process
# This definitively proves that both processes access the same memory
self.assertAlmostEqual(
shared_tensor[0][0],
42.0,
places=6,
msg=f"Expected subprocess signature 42.0, got {shared_tensor[0]}. "
f"Shared memory not working - subprocess modifications not visible!",
)
checkpoint_process.close()
if __name__ == "__main__":
run_tests()

View File

@@ -0,0 +1,216 @@
# Owner(s): ["oncall: distributed checkpointing"]
from concurrent.futures import Future
import torch
from torch.distributed.checkpoint._experimental.staging import (
CheckpointStagerConfig,
DefaultStager,
)
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
class TestDefaultStager(TestCase):
def setUp(self) -> None:
# Create a test state dictionary with various data types
self.state_dict = {
"model": torch.nn.Linear(10, 5).state_dict(),
"optimizer": {"param_groups": [{"lr": 0.01}]},
"epoch": 5,
"step": 1000,
"tensor": torch.randn(3, 4),
"nested": {"inner_tensor": torch.ones(2, 2), "inner_value": 42},
}
@requires_cuda
def test_sync_staging(self) -> None:
"""Test synchronous staging."""
options = CheckpointStagerConfig(use_async_staging=False)
stager = DefaultStager(options)
# Stage the state dict
staged_dict = stager.stage(self.state_dict)
# Verify that a state dict is returned (not a Future)
self.assertIsInstance(staged_dict, dict)
# Verify the staged state dictionary
self.assertIn("model", staged_dict)
self.assertIn("optimizer", staged_dict)
self.assertEqual(staged_dict["epoch"], 5)
self.assertEqual(staged_dict["step"], 1000)
self.assertIn("tensor", staged_dict)
self.assertIn("nested", staged_dict)
# Clean up
stager.close()
@requires_cuda
def test_async_staging(self) -> None:
"""Test asynchronous staging."""
options = CheckpointStagerConfig(use_async_staging=True)
stager = DefaultStager(options)
# Stage the state dict
result = stager.stage(self.state_dict)
# Verify that a Future is returned
self.assertIsInstance(result, Future)
# Wait for the Future to complete
staged_dict = result.result()
# Verify the staged state dictionary
self.assertIn("model", staged_dict)
self.assertIn("optimizer", staged_dict)
self.assertEqual(staged_dict["epoch"], 5)
self.assertEqual(staged_dict["step"], 1000)
# Clean up
stager.close()
def test_cuda_non_blocking_without_cuda(self) -> None:
"""Test that non-blocking copy fails when CUDA is not available."""
if torch.cuda.is_available():
self.skipTest("CUDA is available, cannot test CUDA unavailable scenario")
options = CheckpointStagerConfig(use_cuda_non_blocking_copy=True)
with self.assertRaises(AssertionError):
DefaultStager(options)
def test_different_option_combinations(self) -> None:
"""Test various combinations of staging options."""
test_cases = [
# All disabled
CheckpointStagerConfig(
use_pinned_memory=False,
use_shared_memory=False,
use_async_staging=False,
use_cuda_non_blocking_copy=False,
),
# Only pinned memory
CheckpointStagerConfig(
use_pinned_memory=True,
use_shared_memory=False,
use_async_staging=False,
use_cuda_non_blocking_copy=False,
),
# Only shared memory
CheckpointStagerConfig(
use_pinned_memory=False,
use_shared_memory=True,
use_async_staging=False,
use_cuda_non_blocking_copy=False,
),
]
if torch.cuda.is_available():
# Only async staging
test_cases.append(
CheckpointStagerConfig(
use_pinned_memory=torch.cuda.is_available(),
use_shared_memory=False,
use_async_staging=True,
use_cuda_non_blocking_copy=False,
)
)
# Only CUDA non-blocking copy
test_cases.append(
CheckpointStagerConfig(
use_pinned_memory=torch.cuda.is_available(),
use_shared_memory=False,
use_async_staging=False,
use_cuda_non_blocking_copy=torch.cuda.is_available(),
)
)
for options in test_cases:
with self.subTest(options=options):
stager = DefaultStager(options)
# Test staging works with these options
if options.use_async_staging and torch.cuda.is_available():
result = stager.stage(self.state_dict)
self.assertIsInstance(result, Future)
staged_dict = result.result()
else:
staged_dict = stager.stage(self.state_dict)
self.assertIsInstance(staged_dict, dict)
self.assertIn("model", staged_dict)
stager.close()
@requires_cuda
def test_cuda_tensors_staging(self) -> None:
"""Test staging with CUDA tensors."""
# Create state dict with CUDA tensors
cuda_state_dict = {
"cuda_tensor": torch.randn(3, 4).cuda(),
"cpu_tensor": torch.randn(2, 3),
"mixed_model": {
"weight": torch.randn(5, 5).cuda(),
"bias": torch.randn(5).cuda(),
},
}
options = CheckpointStagerConfig(use_async_staging=False)
stager = DefaultStager(options)
staged_dict = stager.stage(cuda_state_dict)
assert isinstance(staged_dict, dict)
# Verify tensors are staged (should be moved to CPU)
self.assertIn("cuda_tensor", staged_dict)
self.assertIn("cpu_tensor", staged_dict)
self.assertIn("mixed_model", staged_dict)
stager.close()
@requires_cuda
def test_resource_cleanup(self) -> None:
"""Test that resources are properly cleaned up."""
options = CheckpointStagerConfig(use_async_staging=False)
stager = DefaultStager(options)
# Verify initial state
self.assertIsNotNone(stager._state_dict_stager)
# Close and verify cleanup
stager.close()
def test_multiple_staging_operations(self) -> None:
"""Test multiple staging operations with the same stager."""
options = CheckpointStagerConfig(
use_async_staging=False,
use_pinned_memory=torch.cuda.is_available(),
use_shared_memory=False,
use_cuda_non_blocking_copy=torch.cuda.is_available(),
)
stager = DefaultStager(options)
# Stage multiple different state dicts
state_dicts = [
{"model1": torch.nn.Linear(5, 3).state_dict()},
{"model2": torch.nn.Conv2d(3, 16, 3).state_dict()},
{"optimizer": {"lr": 0.001, "momentum": 0.9}},
]
staged_results = []
for state_dict in state_dicts:
staged_dict = stager.stage(state_dict)
staged_results.append(staged_dict)
# Verify all staging operations succeeded
self.assertEqual(len(staged_results), 3)
for i, result in enumerate(staged_results):
self.assertIsInstance(result, dict)
# Verify the result contains the expected keys
for key in state_dicts[i].keys():
self.assertIn(key, result)
stager.close()
if __name__ == "__main__":
run_tests()

View File

@@ -19,9 +19,14 @@ from .barriers import (
create_barrier_from_config,
TCPStoreBarrier,
)
from .builder import make_async_checkpointer, make_sync_checkpointer
from .checkpoint_reader import CheckpointReader
from .checkpoint_writer import CheckpointWriter, CheckpointWriterConfig, WriterHook
from .checkpointer import AsyncCheckpointer, Checkpointer, SyncCheckpointer
from .config import CheckpointerConfig
from .staging import CheckpointStager, CheckpointStagerConfig, DefaultStager
from .types import RankInfo, STATE_DICT
from .utils import wrap_future
__all__ = [
@@ -31,8 +36,18 @@ __all__ = [
"CheckpointWriter",
"CheckpointWriterConfig",
"WriterHook",
"Checkpointer",
"SyncCheckpointer",
"AsyncCheckpointer",
"CheckpointerConfig",
"BarrierConfig",
"create_barrier_from_config",
"CheckpointStager",
"CheckpointStagerConfig",
"DefaultStager",
"RankInfo",
"STATE_DICT",
"wrap_future",
"make_sync_checkpointer",
"make_async_checkpointer",
]

View File

@@ -0,0 +1,173 @@
"""
Factory functions for creating checkpointer instances with sensible defaults.
This module provides high-level factory functions that simplify the creation
of checkpointer instances by automatically handling component initialization
and configuration with reasonable defaults.
"""
from typing import Any, Callable, Optional
import torch.distributed as dist
from .barriers import create_barrier_from_config
from .checkpoint_process import CheckpointProcess
from .checkpoint_reader import CheckpointReader
from .checkpoint_writer import CheckpointWriter, CheckpointWriterConfig, WriterHook
from .checkpointer import AsyncCheckpointer, SyncCheckpointer
from .config import CheckpointerConfig
from .staging import DefaultStager
from .types import RankInfo
def _get_default_rank_info() -> RankInfo:
"""
Get default rank information from the current distributed environment.
Returns:
RankInfo: Rank information from the default process group if initialized,
otherwise single-rank fallback.
"""
if dist.is_initialized():
return RankInfo(
global_world_size=dist.get_world_size(),
global_rank=dist.get_rank(),
)
else:
# Single-rank fallback
return RankInfo(global_world_size=1, global_rank=0)
def default_subprocess_init_fn(*_: Any) -> None:
"""Default subprocess initialization function (no-op)."""
def default_writer_init_fn(rank_info: RankInfo) -> CheckpointWriter:
"""Default checkpoint writer initialization function."""
return CheckpointWriter(
config=CheckpointWriterConfig(),
rank_info=rank_info,
)
def make_sync_checkpointer(
config: CheckpointerConfig = CheckpointerConfig(),
rank_info: Optional[RankInfo] = None,
commit_hook: Optional[WriterHook] = None,
) -> SyncCheckpointer:
"""
Factory function to create a SyncCheckpointer instance with sensible defaults.
This function creates a synchronous checkpointer with default components, automatically
detecting rank information from the default process group if available, and using the
provided component configurations.
Args:
config: CheckpointerConfig containing component-specific configurations
(writer_config, staging_config, process_config). Defaults to CheckpointerConfig().
rank_info: RankInfo for distributed training. Defaults to auto-detection from
the default PyTorch distributed process group if initialized, otherwise
falls back to single-rank (world_size=1, rank=0).
commit_hook: Optional hook for custom actions before and after checkpoint commits.
Returns:
SyncCheckpointer: A configured synchronous checkpointer instance.
Examples:
# Simplest usage - auto-detect rank, default config
checkpointer = make_sync_checkpointer()
# Explicit rank configuration
checkpointer = make_sync_checkpointer(
rank_info=RankInfo(global_world_size=4, global_rank=0)
)
# Disable barrier
from .barriers import BarrierConfig
config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
checkpointer = make_sync_checkpointer(config=config)
"""
if rank_info is None:
rank_info = _get_default_rank_info()
reader = CheckpointReader(
rank_info=rank_info,
)
barrier = create_barrier_from_config(config.barrier_config)
writer = CheckpointWriter(
config=config.writer_config,
rank_info=rank_info,
barrier=barrier,
commit_hook=commit_hook,
)
return SyncCheckpointer(
writer=writer,
reader=reader,
)
def make_async_checkpointer(
config: CheckpointerConfig = CheckpointerConfig(),
rank_info: Optional[RankInfo] = None,
subprocess_init_fn: Callable[..., None] = default_subprocess_init_fn,
subprocess_init_args: tuple[Any, ...] = (),
checkpoint_writer_init_fn: Callable[..., CheckpointWriter] = default_writer_init_fn,
checkpoint_writer_init_args: Optional[dict[str, Any]] = None,
) -> AsyncCheckpointer:
"""
Factory function to create an AsyncCheckpointer instance with sensible defaults.
This function creates an asynchronous checkpointer using the provided configuration,
automatically detecting rank information if not provided.
Args:
config: CheckpointerConfig containing component-specific configurations.
rank_info: RankInfo for distributed training. Defaults to auto-detection.
subprocess_init_fn: Function to initialize the subprocess. Defaults to no-op.
subprocess_init_args: Arguments to pass to subprocess_init_fn.
checkpoint_writer_init_fn: Function to create CheckpointWriter instance.
checkpoint_writer_init_args: Arguments to pass to checkpoint_writer_init_fn.
Returns:
AsyncCheckpointer: A configured asynchronous checkpointer instance.
Examples:
# Create with default config
checkpointer = make_async_checkpointer()
# Create with custom init functions
checkpointer = make_async_checkpointer(
subprocess_init_fn=my_subprocess_init_fn,
checkpoint_writer_init_fn=my_writer_init_fn
)
"""
if rank_info is None:
rank_info = _get_default_rank_info()
reader = CheckpointReader(
rank_info=rank_info,
)
checkpoint_stager = DefaultStager(
config=config.staging_config,
)
checkpoint_writer_init_args = checkpoint_writer_init_args or {}
checkpoint_process = CheckpointProcess(
rank_info=rank_info,
config=config.process_config,
subprocess_init_fn=subprocess_init_fn,
subprocess_init_args=subprocess_init_args,
checkpoint_writer_init_fn=checkpoint_writer_init_fn,
checkpoint_writer_init_args=checkpoint_writer_init_args,
)
return AsyncCheckpointer(
checkpoint_stager=checkpoint_stager,
checkpoint_process=checkpoint_process,
reader=reader,
)

View File

@@ -0,0 +1,361 @@
import logging
import os
import traceback
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from multiprocessing.connection import Connection
from typing import Any, Callable, Optional, Union
import torch.multiprocessing as mp
from torch.multiprocessing.spawn import ProcessExitedException
from .checkpoint_writer import CheckpointWriter
from .types import RankInfo, STATE_DICT
logger = logging.getLogger(__name__)
@dataclass
class CheckpointProcessConfig:
"""
Configuration options for the CheckpointProcess.
This class provides configuration options for the checkpoint subprocess,
including initialization and shutdown timeouts.
Attributes:
subprocess_init_timeout_secs: Maximum time in seconds to wait for subprocess initialization.
subprocess_shutdown_timeout_secs: Maximum time in seconds to wait for subprocess shutdown.
"""
subprocess_init_timeout_secs: int = 30
subprocess_shutdown_timeout_secs: int = 60
class RequestType(Enum):
PING = "ping"
WRITE_CHECKPOINT = "write_checkpoint"
TERMINATE_PROCESS = "exit"
@dataclass
class WorkerRequest:
"""
A dataclass for storing the command to be sent to the worker process.
Note: This relies on pickling to send the command to the worker process. Handle
backward compatibility accordingly.
"""
request_type: RequestType
payload: dict[str, Any]
@dataclass
class WorkerResponse:
request_type: RequestType
success: bool
error_msg: Optional[str] = None
payload: Optional[dict[str, Any]] = None
class CheckpointProcess:
"""
A checkpoint writer that forwards checkpoint writes to a separate background process.
"""
def __init__(
self,
rank_info: RankInfo,
config: CheckpointProcessConfig,
subprocess_init_fn: Callable[[Any], None],
subprocess_init_args: tuple[Any, ...],
checkpoint_writer_init_fn: Callable[..., CheckpointWriter],
checkpoint_writer_init_args: dict[str, Any],
):
self._executor = ThreadPoolExecutor(max_workers=1)
self._rank_info = rank_info
self._config = config
self._subprocess_init_fn = subprocess_init_fn
self._subprocess_init_args = subprocess_init_args
self._checkpoint_writer_init_fn = checkpoint_writer_init_fn
self._checkpoint_writer_init_args = checkpoint_writer_init_args
self.process = None
self._parent_end: Optional[Connection] = None
self._child_end: Optional[Connection] = None
self.process_creation_future = self._executor.submit(
self._create_subprocess,
config,
)
def _create_subprocess(
self,
config: CheckpointProcessConfig,
) -> None:
logger.info(
"Creating checkpoint subprocess for rank %d", self._rank_info.global_rank
)
spawn_context = mp.get_context("spawn")
self._parent_end, child_end = spawn_context.Pipe()
# Known workaround for https://github.com/pytorch/pytorch/issues/37377
os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
logger.debug("Spawning subprocess for rank_info=%s", self._rank_info)
self.process = mp.spawn(
fn=CheckpointProcess._subprocess,
args=(
self._rank_info,
child_end,
self._subprocess_init_fn,
self._subprocess_init_args,
self._checkpoint_writer_init_fn,
self._checkpoint_writer_init_args,
),
nprocs=1,
join=False,
daemon=True,
)
# close the child end of the pipe so recv on it will fail
# fast when the child process is terminated unexpectedly.
child_end.close()
self._send(
request_type=RequestType.PING,
payload={},
)
logger.debug(
"Waiting for checkpoint subprocess to initialize (timeout: %ds)",
config.subprocess_init_timeout_secs,
)
# wait for the timeout or a response from subprocess
assert self._parent_end is not None, "Parent end of pipe should be initialized"
if not self._parent_end.poll(timeout=config.subprocess_init_timeout_secs):
msg = f"Timed out after {config.subprocess_init_timeout_secs}s waiting for checkpoint subprocess to initialize"
logger.error(msg)
raise TimeoutError(msg)
self._recv()
logger.info("Checkpoint subprocess initialized successfully")
@staticmethod
def _subprocess(
sub_rank: int,
rank_info: RankInfo,
parent_pipe: Connection,
subprocess_init_fn: Callable[[Any], None],
subprocess_init_args: tuple[Any, ...],
checkpoint_writer_init_fn: Callable[..., CheckpointWriter],
checkpoint_writer_init_args: dict[str, Any],
) -> None:
logger.debug(
"Checkpoint subprocess started for rank %d/%d (PID: %d)",
rank_info.global_rank,
rank_info.global_world_size,
os.getpid(),
)
assert sub_rank == 0, "We need only one checkpointer per parent training"
request = WorkerRequest(request_type=RequestType.PING, payload={})
try:
# Calling initialize callback, so we can perform app-specific initialization of the subprocess.
subprocess_init_fn(*subprocess_init_args)
# Initialize checkpoint writer - automatically include rank_info in init_args
writer_init_args = dict(checkpoint_writer_init_args)
if "rank_info" not in writer_init_args:
writer_init_args["rank_info"] = rank_info
checkpoint_writer = checkpoint_writer_init_fn(**writer_init_args)
while True:
request = parent_pipe.recv()
if request.request_type == RequestType.PING:
parent_pipe.send(
WorkerResponse(request_type=RequestType.PING, success=True)
)
elif request.request_type == RequestType.WRITE_CHECKPOINT:
path = request.payload["path"]
logger.info("Writing checkpoint to %s", path)
checkpoint_writer.write(
state_dict=request.payload["state_dict"],
path=path,
**request.payload["kwargs"],
)
logger.info("Checkpoint written successfully to %s", path)
parent_pipe.send(
WorkerResponse(RequestType.WRITE_CHECKPOINT, success=True)
)
elif request.request_type == RequestType.TERMINATE_PROCESS:
logger.debug("Received termination request.")
parent_pipe.send(
WorkerResponse(RequestType.TERMINATE_PROCESS, success=True)
)
logger.info("Subprocess terminated gracefully")
break
else:
error_msg = f"Unknown request type: {request.request_type}"
logger.error(error_msg)
raise ValueError(error_msg)
except Exception as e:
error_text = traceback.format_exc()
logger.error(
"Exception in subprocess (%s): %s", type(e).__name__, error_text
)
# Communicating exception via the queue to the main process
parent_pipe.send(
WorkerResponse(
request_type=request.request_type,
success=False,
error_msg=error_text,
)
)
parent_pipe.close()
logger.error("Subprocess terminated due to exception: %s", e)
def _send(self, request_type: RequestType, payload: dict[str, Any]) -> None:
try:
assert self._parent_end is not None, (
"Parent end of pipe should be initialized"
)
self._parent_end.send(
WorkerRequest(
request_type=request_type,
payload=payload,
)
)
except OSError as e:
error_msg = "Child process terminated unexpectedly"
logger.error(
"Communication failed during %s request: %s", request_type.value, e
)
raise RuntimeError(error_msg) from e
def _recv(self) -> Optional[dict[str, Any]]:
try:
assert self._parent_end is not None, (
"Parent end of pipe should be initialized"
)
response = self._parent_end.recv()
if response.success is False:
error_msg = (
f"Unexpected response from worker process: {response.error_msg}"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
return response.payload
except (EOFError, BrokenPipeError, ConnectionResetError) as e:
error_msg = f"Child process terminated unexpectedly: {e}"
logger.error(error_msg)
raise RuntimeError(error_msg) from e
def write(
self,
state_dict: Union[STATE_DICT, Future[STATE_DICT]],
path: str,
**kwargs: Any,
) -> Optional[Future[None]]:
logger.debug("Waiting for subprocess initialization to complete")
# wait until the process is started
self.process_creation_future.result()
return self._executor.submit(
self._write,
state_dict,
path,
**kwargs,
)
def _write(
self,
state_dict: Union[STATE_DICT, Future[STATE_DICT]],
path: str,
**kwargs: Any,
) -> None:
logger.debug("Starting checkpoint write to %s", path)
# wait for staging state_dict to be available
if isinstance(state_dict, Future):
logger.debug("Waiting for state_dict Future to resolve")
sd = state_dict.result()
else:
sd = state_dict
# Log state_dict info only if debug logging is enabled (performance-conscious)
if logger.isEnabledFor(logging.DEBUG):
if hasattr(sd, "keys"):
logger.debug("State_dict contains %d keys", len(sd.keys()))
self._send(
request_type=RequestType.WRITE_CHECKPOINT,
payload={
"state_dict": sd,
"path": path,
"kwargs": kwargs,
},
)
logger.debug("Waiting for write completion response")
# wait for response
self._recv()
logger.debug("Checkpoint write to %s completed successfully", path)
def close(self) -> None:
logger.debug(
"Closing CheckpointProcess for rank %d", self._rank_info.global_rank
)
self._executor.shutdown(wait=True, cancel_futures=True)
if self.process and self.process.processes[0].is_alive():
subprocess_pid = self.process.processes[0].pid
# send graceful termination to sub process
try:
self._parent_end.send(
WorkerRequest(
request_type=RequestType.TERMINATE_PROCESS,
payload={},
)
)
except BrokenPipeError:
logger.warning(
"BrokenPipeError when sending termination request - subprocess (PID: %d) may have already terminated",
subprocess_pid,
)
# subprocess terminated unexpectedly and below code will raise a
# ProcessExitedException.
logger.debug(
"Waiting for subprocess to terminate gracefully (timeout: %ds)",
self._config.subprocess_shutdown_timeout_secs,
)
try:
if not self.process.join(
timeout=self._config.subprocess_shutdown_timeout_secs
):
# graceful shutdown failed, kill the process.
logger.warning(
"Subprocess (PID: %d) did not terminate gracefully within %ds, killing it",
subprocess_pid,
self._config.subprocess_shutdown_timeout_secs,
)
self.process.processes[0].kill()
logger.info("Subprocess killed forcefully")
except ProcessExitedException as e:
logger.error(
"ProcessExitedException during subprocess termination: %s", e
)
raise
logger.debug("CheckpointProcess closed successfully")

View File

@@ -3,9 +3,12 @@ import logging
from concurrent.futures import Future
from typing import Any, Optional, TypeVar
from .checkpoint_process import CheckpointProcess
from .checkpoint_reader import CheckpointReader
from .checkpoint_writer import CheckpointWriter
from .staging import CheckpointStager
from .types import STATE_DICT
from .utils import wrap_future
logger = logging.getLogger(__name__)
@@ -190,3 +193,149 @@ class SyncCheckpointer(Checkpointer):
"""
self._writer.close()
logger.info("SyncCheckpointer closed")
class AsyncCheckpointer(Checkpointer):
"""
Asynchronous implementation of Checkpointer.
This class coordinates the writing and loading of model state dictionaries to and from storage
using asynchronous operations for saving. It provides efficient async checkpoint operations
with staging and background writing capabilities.
Attributes:
_reader: CheckpointReader for reading state dictionaries from storage.
_checkpoint_stager: Stager for async operations.
_checkpoint_process: Process for async operations.
_write_future: Future representing the ongoing async write operation.
Example:
checkpointer = AsyncCheckpointer(
reader=reader,
checkpoint_stager=stager,
checkpoint_process=process
)
stage_future, write_future = checkpointer.save(state_dict, path)
# ... do other work ...
write_future.result() # Wait for completion
"""
def __init__(
self,
checkpoint_stager: CheckpointStager,
checkpoint_process: CheckpointProcess,
reader: CheckpointReader,
):
"""
Initialize an asynchronous checkpointer.
Args:
checkpoint_stager: Stager for async operations.
checkpoint_process: Process for async operations.
reader: CheckpointReader for reading checkpoints from storage.
"""
self._reader = reader
self._checkpoint_stager = checkpoint_stager
self._checkpoint_process = checkpoint_process
self._write_future: Optional[Future[Any]] = None
def save(
self,
state_dict: STATE_DICT,
path: str,
**kwargs: Any,
) -> Optional[tuple[Future, Future]]:
"""
Save a state dictionary to storage asynchronously.
Args:
state_dict: The state dictionary to save.
path: The path where the checkpoint should be saved.
**kwargs: Additional keyword arguments to pass to the stager and writer.
Returns:
A tuple of (stage_future, write_future) representing the staging and writing operations.
Example:
stage_future, write_future = checkpointer.save(state_dict, "/path/to/checkpoint")
# ... do other work ...
write_future.result() # Wait for completion
"""
logger.info(
"Initiating checkpoint save to %s. Will wait for prev checkpoints to complete.",
path,
)
# Wait for previous checkpoint ops to finish and verify they are successful
if self._write_future is not None:
self._write_future.result()
logger.debug("Starting state dictionary staging")
staging_result = self._checkpoint_stager.stage(
state_dict=state_dict,
**kwargs,
)
logger.debug("Starting checkpoint write to %s", path)
self._write_future = self._checkpoint_process.write(
staging_result, path, **kwargs
)
logger.info("Checkpoint save to %s initiated", path)
# Return futures for the staging and writing operations
if self._write_future is not None:
return wrap_future(staging_result), self._write_future
else:
# This should not happen since we just assigned _write_future above
raise RuntimeError("Write future is unexpectedly None")
def load(
self,
path: str,
state_dict: Optional[STATE_DICT] = None,
*,
default_map_location: Any = None,
strict: bool = False,
**kwargs: Any,
) -> STATE_DICT:
"""
Load a state dictionary from storage.
Loading is always performed synchronously, even in AsyncCheckpointer.
Args:
path: The path from which to load the checkpoint.
state_dict: Optional state dictionary to update with loaded values.
If provided, only keys in this dictionary will be loaded.
default_map_location: Device mapping function or device name for relocating tensors.
strict: If True, raises an error when there are missing keys in the checkpoint.
**kwargs: Additional keyword arguments to pass to the reader.
Returns:
The loaded state dictionary.
Raises:
RuntimeError: If strict=True and there are missing keys in the checkpoint.
FileNotFoundError: If the checkpoint file is not found.
"""
logger.info("Loading checkpoint from %s", path)
loaded_state_dict, missing_keys = self._reader.read(
path=path,
state_dict=state_dict,
map_location=default_map_location,
**kwargs,
)
if strict and missing_keys is not None and missing_keys != []:
raise RuntimeError(f"Checkpoint at {path} is missing keys: {missing_keys}")
return loaded_state_dict
def close(self) -> None:
"""
Close the checkpointer and release any resources.
This method should be called when the checkpointer is no longer needed to ensure
proper cleanup of async resources.
"""
self._checkpoint_stager.close()
self._checkpoint_process.close()
logger.info("AsyncCheckpointer closed")

View File

@@ -0,0 +1,44 @@
"""
Configuration classes for checkpointer construction.
This module provides configuration dataclasses that consolidate all
configuration options needed to construct checkpointers.
"""
from dataclasses import dataclass, field
from .barriers import BarrierConfig
from .checkpoint_process import CheckpointProcessConfig
from .checkpoint_writer import CheckpointWriterConfig
from .staging import CheckpointStagerConfig
@dataclass
class CheckpointerConfig:
"""
Configuration class for checkpointer construction.
This class consolidates the core component configuration options needed to construct
a checkpointer, providing a clean separation of concerns where each component
manages its own configuration.
Attributes:
writer_config: Configuration options for the checkpoint writer component.
barrier_config: Configuration for barrier construction and arguments.
staging_config: Configuration options for the async staging component.
process_config: Configuration options for the async checkpoint process component.
"""
writer_config: CheckpointWriterConfig = field(
default_factory=CheckpointWriterConfig
)
barrier_config: BarrierConfig = field(default_factory=BarrierConfig)
# Below configs are used for async checkpointing
staging_config: CheckpointStagerConfig = field(
default_factory=CheckpointStagerConfig
)
process_config: CheckpointProcessConfig = field(
default_factory=CheckpointProcessConfig
)
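
For illustration, a short sketch of composing a CheckpointerConfig from the component configs above (values illustrative, mirroring the new tests):

    from torch.distributed.checkpoint._experimental.barriers import BarrierConfig
    from torch.distributed.checkpoint._experimental.checkpoint_process import CheckpointProcessConfig
    from torch.distributed.checkpoint._experimental.checkpoint_writer import CheckpointWriterConfig
    from torch.distributed.checkpoint._experimental.config import CheckpointerConfig
    from torch.distributed.checkpoint._experimental.staging import CheckpointStagerConfig

    config = CheckpointerConfig(
        writer_config=CheckpointWriterConfig(),
        barrier_config=BarrierConfig(barrier_type=None),                  # no cross-rank barrier
        staging_config=CheckpointStagerConfig(use_async_staging=False),
        process_config=CheckpointProcessConfig(subprocess_init_timeout_secs=10),
    )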

View File

@@ -0,0 +1,216 @@
"""
Experimental staging module for PyTorch Distributed Checkpointing.
This module provides advanced staging capabilities for checkpoints including:
- Asynchronous staging using ThreadPoolExecutor
- Pinned memory allocation for faster CPU-GPU transfers
- Shared memory support for multi-process scenarios
- Non-blocking CUDA operations with stream synchronization
- Caching of frequently used storages for efficient memory management
- Automatic resource cleanup and memory management
Classes:
CheckpointStager: Abstract base class defining the staging interface
CheckpointStagerConfig: Configuration dataclass for staging behavior
DefaultStager: Default implementation with comprehensive staging features
"""
import abc
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from logging import getLogger
from typing import Any, TypeVar, Union
import torch
from torch.distributed.checkpoint._state_dict_stager import StateDictStager
from .types import STATE_DICT
T = TypeVar("T")
logger = getLogger()
logger.setLevel(logging.INFO)
class CheckpointStager(abc.ABC):
"""
Abstract base class for checkpoint staging implementations.
CheckpointStager defines the interface that all staging implementations
must follow. Staging is the process of offloading state dictionaries
for async checkpointing.
"""
@abc.abstractmethod
def stage(
self,
state_dict: STATE_DICT,
**kwargs: Any,
) -> Union[STATE_DICT, Future[STATE_DICT]]:
"""
Stage a state dictionary for checkpointing.
Args:
state_dict: The state dictionary to stage
**kwargs: Additional staging parameters
Returns:
Either a staged state dictionary (synchronous) or a Future
that will resolve to the staged state dictionary (asynchronous)
"""
@abc.abstractmethod
def close(self) -> None:
"""
Clean up all resources used by the stager.
"""
@dataclass
class CheckpointStagerConfig:
"""
Configuration options for checkpoint staging behavior.
Attributes:
use_pinned_memory (bool): Enable pinned memory allocation for faster
CPU-GPU transfers. Requires CUDA to be available. Default: True
use_shared_memory (bool): Enable shared memory for multi-process
scenarios. Useful when multiple processes need access to the
same staged data. Default: True
use_async_staging (bool): Enable asynchronous staging using a
background thread pool. Allows overlapping computation with
staging operations. Requires CUDA. Default: True
use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory
copies with stream synchronization. Improves performance by
allowing CPU work to continue during GPU transfers. Default: True
Note:
CUDA-dependent features will raise an exception if CUDA is not available.
"""
use_pinned_memory: bool = True
use_shared_memory: bool = True
use_async_staging: bool = True
use_cuda_non_blocking_copy: bool = True
class DefaultStager(CheckpointStager):
"""
DefaultStager provides a full-featured staging implementation that combines
multiple optimization techniques for efficient checkpoint preparation.
The staging process works as follows:
1. State dictionary is submitted for staging (sync or async)
2. Tensors are copied from GPU to optimized CPU storage
3. CUDA operations are synchronized if non-blocking copies are used
4. Staged state dictionary is returned or made available via Future
NOTE: state_dict should be a deep-copyable object, as staging will create a
copy of it.
Usage Patterns:
# Synchronous staging
stager = DefaultStager(CheckpointStagerConfig(use_async_staging=False))
staged_dict = stager.stage(state_dict)
stager.close()
# Asynchronous staging
stager = DefaultStager(CheckpointStagerConfig(use_async_staging=True))
future = stager.stage(state_dict)
# ... do other work ...
staged_dict = future.result()
stager.close()
# Context manager pattern (recommended)
with DefaultStager(config) as stager:
result = stager.stage(state_dict)
# Automatic cleanup on exit
Performance Considerations:
- Async staging provides best performance when model computation
can overlap with staging operations
- Pinned memory improves CPU-GPU transfer speeds but uses more memory
- Shared memory allows efficient IPC to checkpoint process
- Non-blocking copies reduce GPU idle time during memory transfers
Thread Safety:
DefaultStager is not thread-safe. Each thread should use its own
instance, or external synchronization should be provided.
"""
def __init__(
self,
config: CheckpointStagerConfig = CheckpointStagerConfig(),
):
self._config = config
self._state_dict_stager = StateDictStager(
pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory
)
self._staging_executor = None
self._staging_stream = None
if self._config.use_async_staging:
self._staging_executor = ThreadPoolExecutor(max_workers=1)
if torch.cuda.is_available():
# Note: stream needs to be initialized on the main thread after default cuda
# stream is setup/used to avoid the risk of accidentally reusing the main
# compute stream or in other cases kernels actually launching from the
# main thread.
self._staging_stream = torch.cuda.Stream()
if self._config.use_cuda_non_blocking_copy:
assert torch.cuda.is_available(), "Non-blocking copy requires CUDA"
def stage(
self,
state_dict: STATE_DICT,
**kwargs: Any,
) -> Union[STATE_DICT, Future[STATE_DICT]]:
if self._config.use_async_staging:
assert self._staging_executor is not None, (
"Staging executor should be initialized for async staging"
)
return self._staging_executor.submit(
self._stage,
state_dict,
**kwargs,
)
else:
return self._stage(state_dict, **kwargs)
def _stage(self, state_dict: STATE_DICT, **kwargs: Any) -> STATE_DICT:
state_dict = self._state_dict_stager.stage(
state_dict, non_blocking=self._config.use_cuda_non_blocking_copy, **kwargs
)
if self._config.use_cuda_non_blocking_copy:
assert self._staging_stream or not self._config.use_async_staging, (
"Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized."
)
# Wait for the enqueued copy operations to finish.
self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize()
return state_dict
def close(self) -> None:
"""
Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor
used for async staging operations and cleans up the underlying StateDictStager's
cached storages. Should be called when the stager is no longer needed to prevent
resource leaks, especially in long-running applications. After calling close(),
the stager should not be used for further staging operations.
Note that state_dict should be a deep-copyable object.
Example:
stager = DefaultStager(CheckpointStagerConfig(use_async_staging=True))
# ... do staging operations ...
stager.close() # Clean up all resources
"""
if self._staging_executor:
self._staging_executor.shutdown(wait=True)
self._state_dict_stager.close()

View File

@@ -0,0 +1,42 @@
"""
Utility functions for the experimental checkpoint module.
This module contains helper functions and utilities used across the experimental
checkpoint functionality.
"""
from concurrent.futures import Future
from typing import Any
def wrap_future(original_result: Any) -> Future[None]:
"""
Wraps a result (Future or not) to return a Future with None result.
If the input is a Future, returns a new Future that completes with None when
the original Future completes successfully, or propagates any exception.
If the input is not a Future, returns a completed Future with None result.
Args:
original_result: The result to wrap (Future or any other value).
Returns:
A Future that completes with None on success or propagates exceptions.
"""
masked_future: Future[None] = Future()
if isinstance(original_result, Future):
def on_complete(_: Future[Any]) -> None:
try:
original_result.result()
masked_future.set_result(None)
except Exception as e:
masked_future.set_exception(e)
original_result.add_done_callback(on_complete)
else:
# Return a completed future with None result
masked_future.set_result(None)
return masked_future
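
A small sketch of the wrap_future contract described above (editor's illustration, using only the semantics in the docstring):

    from concurrent.futures import Future

    staged: Future = Future()
    masked = wrap_future(staged)            # not yet done
    staged.set_result({"epoch": 5})         # original future completes ...
    assert masked.result() is None          # ... and the wrapped future resolves to None

    assert wrap_future("not a future").result() is None  # non-Future input -> already-completed Future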

View File

@@ -223,6 +223,16 @@ class StateDictStager:
return y
def close(self):
"""
Clean up all cached storages and release associated resources.
This method clears the internal storage cache, allowing garbage collection
of cached CPU storages. Any pinned memory associated with cached storages
will be automatically unpinned through weak reference finalizers.
"""
self._cached_storage_mapping.clear()
@torch.no_grad()
def deepcopy_with_tensor_offload(self, x, memo=None, _nil=[], non_blocking=False): # noqa: B006
"""Deep copy operation on arbitrary Python objects with special handling for PyTorch tensors.

View File

@@ -3,6 +3,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import io
import logging
import os
import shutil
import tempfile
@@ -157,3 +158,36 @@ def with_temp_dir(
shutil.rmtree(self.temp_dir, ignore_errors=True)
return wrapper
def with_checkpoint_logging(
func: Optional[Callable] = None,
logger_name: str = "torch.distributed.checkpoint",
level: int = logging.INFO,
) -> Optional[Callable]:
"""
Wrapper to configure checkpoint logging for distributed tests.
Args:
func: The test function to wrap
logger_name: Name of the logger to configure (default: 'torch.distributed.checkpoint')
level: Logging level to set (default: logging.INFO)
"""
assert func is not None
@wraps(func)
def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
# Get the logger and store original level
target_logger = logging.getLogger(logger_name)
original_level = target_logger.level
# Set the desired logging level
target_logger.setLevel(level)
try:
func(self, *args, **kwargs)
finally:
# Restore original logging level
target_logger.setLevel(original_level)
return wrapper
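
A hypothetical usage sketch of the decorator above (test name illustrative). As written, the decorator is applied bare, so logger_name and level take their defaults:

    class MyCheckpointLoggingTest(TestCase):
        @with_checkpoint_logging
        def test_save_logs_at_info_level(self) -> None:
            ...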