Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Add async checkpointing impl to experimental checkpointer and add a builder API (#156927)
1. Adds an AsyncCheckpointer with out-of-process checkpointing and a state_dict_stager with shared memory, pinned memory, and Zero Overhead support.
2. Adds two convenient functions to create sync/async checkpointers.
Differential Revision: [D77336833](https://our.internmc.facebook.com/intern/diff/D77336833/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156927
Approved by: https://github.com/pradeepfn
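As context for the diff below, here is a minimal single-rank usage sketch of the new builder API, assembled from the factory docstrings and tests in this commit; the state_dict contents and the /tmp checkpoint paths are illustrative placeholders, not part of the change, and the CUDA-conditional staging flags mirror what the tests do.

import torch
from torch.distributed.checkpoint._experimental.barriers import BarrierConfig
from torch.distributed.checkpoint._experimental.builder import (
    make_async_checkpointer,
    make_sync_checkpointer,
)
from torch.distributed.checkpoint._experimental.config import CheckpointerConfig
from torch.distributed.checkpoint._experimental.staging import CheckpointStagerConfig

state_dict = {"model": torch.nn.Linear(10, 5).state_dict(), "step": 1000}

# Synchronous path: save() blocks until the checkpoint is written and returns None.
sync_config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
sync_ckpt = make_sync_checkpointer(config=sync_config)
sync_ckpt.save(state_dict, "/tmp/ckpt_sync")
loaded = sync_ckpt.load("/tmp/ckpt_sync")

# Asynchronous path: staging plus an out-of-process write; save() returns two futures.
async_config = CheckpointerConfig()
async_config.staging_config = CheckpointStagerConfig(
    use_pinned_memory=torch.cuda.is_available(),
    use_cuda_non_blocking_copy=torch.cuda.is_available(),
)
async_ckpt = make_async_checkpointer(config=async_config)
stage_future, write_future = async_ckpt.save(state_dict, "/tmp/ckpt_async")
write_future.result()  # wait for the background write to complete
async_ckpt.close()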
This commit is contained in:
parent 7081b8233a
commit dd3e7170c2
165  test/distributed/checkpoint/_experimental/test_builder.py  (Normal file)
@@ -0,0 +1,165 @@
# Owner(s): ["oncall: distributed checkpointing"]

import os
import shutil
import tempfile

import torch
from torch.distributed.checkpoint._experimental.barriers import BarrierConfig
from torch.distributed.checkpoint._experimental.builder import (
    make_async_checkpointer,
    make_sync_checkpointer,
)
from torch.distributed.checkpoint._experimental.checkpointer import (
    AsyncCheckpointer,
    SyncCheckpointer,
)
from torch.distributed.checkpoint._experimental.config import CheckpointerConfig
from torch.distributed.checkpoint._experimental.staging import CheckpointStagerConfig
from torch.distributed.checkpoint._experimental.types import RankInfo
from torch.testing._internal.common_utils import run_tests, TestCase


class TestMakeCheckpointer(TestCase):
    def setUp(self) -> None:
        # Create a temporary directory for checkpoints
        self.temp_dir = tempfile.mkdtemp()

        # Create real objects for testing
        self.rank_info = RankInfo(
            global_world_size=1,
            global_rank=0,
        )

        # Create a test state dictionary
        self.state_dict = {
            "model": torch.nn.Linear(10, 5).state_dict(),
            "optimizer": {"param_groups": [{"lr": 0.01}]},
            "epoch": 5,
            "step": 1000,
        }

    def tearDown(self) -> None:
        # Clean up the temporary directory
        shutil.rmtree(self.temp_dir)

    def test_make_sync_checkpointer(self) -> None:
        """Test creating a synchronous checkpointer using make_sync_checkpointer."""

        # Create sync checkpointer using factory function with no barrier
        config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
        checkpointer = make_sync_checkpointer(config=config, rank_info=self.rank_info)

        # Verify it's a SyncCheckpointer instance
        self.assertIsInstance(checkpointer, SyncCheckpointer)

        # Test that it works for sync operations
        checkpoint_path = os.path.join(self.temp_dir, "checkpoint_factory_sync")
        result = checkpointer.save(self.state_dict, checkpoint_path)
        self.assertIsNone(result)  # Sync mode returns None

        # Verify checkpoint was created
        checkpoint_file = os.path.join(
            checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
        )
        self.assertTrue(os.path.exists(checkpoint_file))

        # Test loading
        loaded_state_dict = checkpointer.load(checkpoint_path)
        self.assertEqual(loaded_state_dict["epoch"], 5)

    def test_make_sync_checkpointer_with_config_first(self) -> None:
        """Test creating a synchronous checkpointer with config as first parameter."""
        # Create sync checkpointer with config as first parameter
        config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
        checkpointer = make_sync_checkpointer(config=config, rank_info=self.rank_info)

        # Verify it's a SyncCheckpointer instance
        self.assertIsInstance(checkpointer, SyncCheckpointer)

        # Test that it works for sync operations
        checkpoint_path = os.path.join(
            self.temp_dir, "checkpoint_factory_sync_config_first"
        )
        result = checkpointer.save(self.state_dict, checkpoint_path)
        self.assertIsNone(result)  # Sync mode returns None

        # Verify checkpoint was created
        checkpoint_file = os.path.join(
            checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
        )
        self.assertTrue(os.path.exists(checkpoint_file))

    def test_make_sync_checkpointer_with_custom_config(self) -> None:
        """Test creating a synchronous checkpointer with a custom config."""
        # Create a custom config with no barrier
        config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))

        # Create sync checkpointer with the custom config
        checkpointer = make_sync_checkpointer(rank_info=self.rank_info, config=config)

        # Verify it's a SyncCheckpointer instance
        self.assertIsInstance(checkpointer, SyncCheckpointer)

        # Test that it works for sync operations
        checkpoint_path = os.path.join(
            self.temp_dir, "checkpoint_factory_sync_custom_config"
        )
        result = checkpointer.save(self.state_dict, checkpoint_path)
        self.assertIsNone(result)  # Sync mode returns None

        # Verify checkpoint was created
        checkpoint_file = os.path.join(
            checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
        )
        self.assertTrue(os.path.exists(checkpoint_file))

        # Test loading
        loaded_state_dict = checkpointer.load(checkpoint_path)
        self.assertEqual(loaded_state_dict["epoch"], 5)

    def test_make_async_checkpointer(self) -> None:
        """Test creating an asynchronous checkpointer using make_async_checkpointer."""
        # Create async checkpointer using factory function with default parameters
        config: CheckpointerConfig = CheckpointerConfig()
        config.staging_config = CheckpointStagerConfig(
            use_cuda_non_blocking_copy=torch.cuda.is_available(),
            use_pinned_memory=torch.cuda.is_available(),
        )
        checkpointer = make_async_checkpointer(config=config, rank_info=self.rank_info)

        try:
            # Verify it's an AsyncCheckpointer instance
            self.assertIsInstance(checkpointer, AsyncCheckpointer)

            # Test that it works for async operations
            checkpoint_path = os.path.join(self.temp_dir, "checkpoint_factory_async")
            stage_future, write_future = checkpointer.save(
                self.state_dict, checkpoint_path
            )

            # Verify futures are returned
            self.assertIsNotNone(stage_future)
            self.assertIsNotNone(write_future)

            # Wait for completion
            stage_future.result()
            write_future.result()

            # Verify checkpoint was created
            checkpoint_file = os.path.join(
                checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
            )
            self.assertTrue(os.path.exists(checkpoint_file))

            # Test loading
            loaded_state_dict = checkpointer.load(checkpoint_path)
            self.assertEqual(loaded_state_dict["epoch"], 5)

        finally:
            # Clean up
            checkpointer.close()


if __name__ == "__main__":
    run_tests()
@@ -0,0 +1,465 @@
# Owner(s): ["oncall: distributed checkpointing"]


import os
import tempfile
import time
from concurrent.futures import Future
from typing import Any

import torch
from torch.distributed.checkpoint._experimental.checkpoint_process import (
    CheckpointProcess,
    CheckpointProcessConfig,
    RequestType,
    WorkerRequest,
    WorkerResponse,
)
from torch.distributed.checkpoint._experimental.checkpoint_writer import (
    CheckpointWriter,
    CheckpointWriterConfig,
)
from torch.distributed.checkpoint._experimental.types import RankInfo
from torch.testing._internal.common_utils import run_tests, TestCase


def subprocess_init_fn(name: str, parent_pid: int) -> None:
    """Initialize the subprocess with some basic checks.

    This is similar to the subprocess_init_routine in checkpointing_test.py.
    """
    assert name == "test-checkpointer", f"Unexpected subprocess name: {name}"
    assert os.getpid() != parent_pid, "This was supposed to run in a different process"
    assert os.getppid() == parent_pid, (
        "This was supposed to run as a child to main process"
    )


def failing_subprocess_init_fn(name: str, parent_pid: int) -> None:
    """Initialize function that raises an exception."""
    # Acknowledge parameters to avoid unused variable warnings
    _ = name
    _ = parent_pid
    raise RuntimeError("Subprocess initialization failed")


def timedout_subprocess_init_fn(**kwargs: Any) -> None:
    # Acknowledge parameters to avoid unused variable warnings
    _ = kwargs
    time.sleep(3)  # Simulate a long initialization


def ckpt_writer_init_fn(**kwargs: Any) -> CheckpointWriter:
    """Initialize a CheckpointWriter in the subprocess.

    This function is called in the subprocess to create a CheckpointWriter instance.
    It's important that this function is defined at the module level so it can be pickled.
    """
    return CheckpointWriter(
        config=kwargs.get("config"),
        rank_info=kwargs.get("rank_info"),
    )


def failing_ckpt_writer_init_fn(**kwargs: Any) -> CheckpointWriter:
    """Initialize function that raises an exception."""
    # Acknowledge parameters to avoid unused variable warnings
    _ = kwargs
    raise RuntimeError("CheckpointWriter initialization failed")


def shared_tensor_verifier_init_fn(**kwargs: Any) -> CheckpointWriter:
    """Initialize a CheckpointWriter that verifies shared memory tensors."""

    class SharedTensorVerifier(CheckpointWriter):
        def __init__(self, config=None, rank_info=None, **init_kwargs):
            # Acknowledge unused kwargs to avoid linting warnings
            _ = init_kwargs
            super().__init__(
                config=config or CheckpointWriterConfig(),
                rank_info=rank_info,
                barrier=None,
                commit_hook=None,
            )

        def write(self, state_dict, path, **__):
            # Acknowledge parameters to avoid unused variable warnings
            _ = path

            # Verify shared memory tensor behavior directly with assertions
            if "shared_tensor" in state_dict:
                shared_tensor = state_dict["shared_tensor"]
                # Critical assertion: shared tensor should remain in shared memory in subprocess
                assert shared_tensor.is_shared(), (
                    "Shared tensor should be in shared memory in subprocess"
                )

                shared_tensor[0] = 42.0

            if "regular_tensor" in state_dict:
                # Note: ForkingPickler moves regular tensors to shared memory during IPC - this is acceptable
                assert state_dict["regular_tensor"].is_shared(), (
                    "Regular tensor should also be in shared memory in subprocess"
                )

            return None

    verifier = SharedTensorVerifier(
        config=kwargs.get("config"),
        rank_info=kwargs.get("rank_info"),
    )
    return verifier


class TestRequestTypes(TestCase):
    """Test the request/response data structures."""

    def test_request_type_enum(self) -> None:
        """Test RequestType enum values."""
        self.assertEqual(RequestType.PING.value, "ping")
        self.assertEqual(RequestType.WRITE_CHECKPOINT.value, "write_checkpoint")
        self.assertEqual(RequestType.TERMINATE_PROCESS.value, "exit")

    def test_worker_request(self) -> None:
        """Test WorkerRequest dataclass."""
        request = WorkerRequest(request_type=RequestType.PING, payload={"test": "data"})
        self.assertEqual(request.request_type, RequestType.PING)
        self.assertEqual(request.payload["test"], "data")

    def test_worker_response(self) -> None:
        """Test WorkerResponse dataclass."""
        response = WorkerResponse(
            request_type=RequestType.PING,
            success=True,
            error_msg=None,
            payload={"result": "success"},
        )
        self.assertEqual(response.request_type, RequestType.PING)
        self.assertTrue(response.success)
        self.assertIsNone(response.error_msg)
        self.assertEqual(response.payload["result"], "success")


class TestCheckpointProcessConfig(TestCase):
    """Test CheckpointProcessConfig configuration."""

    def test_default_options(self) -> None:
        """Test default CheckpointProcessConfig."""
        options = CheckpointProcessConfig()
        # Test default values
        self.assertEqual(options.subprocess_init_timeout_secs, 30)
        self.assertEqual(options.subprocess_shutdown_timeout_secs, 60)

    def test_custom_options(self) -> None:
        """Test custom CheckpointProcessConfig."""
        options = CheckpointProcessConfig(
            subprocess_init_timeout_secs=10, subprocess_shutdown_timeout_secs=30
        )
        self.assertEqual(options.subprocess_init_timeout_secs, 10)
        self.assertEqual(options.subprocess_shutdown_timeout_secs, 30)


class TestCheckpointProcess(TestCase):
    def setUp(self) -> None:
        """Set up common test fixtures."""
        self.rank_info = RankInfo(
            global_world_size=1,
            global_rank=0,
        )
        self.writer_config = CheckpointWriterConfig()
        self.test_state_dict = {
            "model": torch.nn.Linear(10, 5).state_dict(),
            "optimizer": {"param_groups": [{"lr": 0.01}]},
            "epoch": 5,
            "step": 1000,
        }

    def _create_checkpoint_process(
        self,
        subprocess_init_fn_override=None,
        subprocess_init_args_override=None,
        writer_init_fn_override=None,
        subprocess_init_timeout_secs=30,
    ):
        """Helper to create CheckpointProcess."""
        config = CheckpointProcessConfig(
            subprocess_init_timeout_secs=subprocess_init_timeout_secs,
        )

        return CheckpointProcess(
            rank_info=self.rank_info,
            config=config,
            subprocess_init_fn=subprocess_init_fn_override or subprocess_init_fn,
            subprocess_init_args=subprocess_init_args_override
            or (
                "test-checkpointer",
                os.getpid(),
            ),
            checkpoint_writer_init_fn=writer_init_fn_override or ckpt_writer_init_fn,
            checkpoint_writer_init_args={
                "config": self.writer_config,
                "rank_info": self.rank_info,
            },
        )

    def test_checkpoint_process_initialization(self) -> None:
        """Test that CheckpointProcess initializes and closes correctly."""
        checkpoint_process = self._create_checkpoint_process()

        # Wait for the process creation future to complete
        checkpoint_process.process_creation_future.result()

        # Verify process is alive
        self.assertTrue(checkpoint_process.process.processes[0].is_alive())

        checkpoint_process.close()

        # Verify process is terminated
        self.assertFalse(checkpoint_process.process.processes[0].is_alive())

    def test_checkpoint_write_sync_state_dict(self) -> None:
        """Test writing a checkpoint with synchronous state dict."""
        checkpoint_process = self._create_checkpoint_process()

        # Wait for initialization
        checkpoint_process.process_creation_future.result()

        # Create a temporary directory for the checkpoint
        with tempfile.TemporaryDirectory() as temp_dir:
            checkpoint_path = os.path.join(temp_dir, "test_checkpoint")

            # Write checkpoint
            future = checkpoint_process.write(self.test_state_dict, checkpoint_path)

            # Verify future is returned
            self.assertIsInstance(future, Future)

            # Wait for completion
            future.result()

            # Verify checkpoint file was created
            expected_file = os.path.join(
                checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
            )
            self.assertTrue(os.path.exists(expected_file))

            # Verify checkpoint content
            loaded_state_dict = torch.load(expected_file)
            self.assertIn("model", loaded_state_dict)
            self.assertIn("optimizer", loaded_state_dict)
            self.assertEqual(loaded_state_dict["epoch"], 5)
            self.assertEqual(loaded_state_dict["step"], 1000)

        checkpoint_process.close()

    def test_checkpoint_write_future_state_dict(self) -> None:
        """Test writing a checkpoint with Future state dict."""
        checkpoint_process = self._create_checkpoint_process()

        # Wait for initialization
        checkpoint_process.process_creation_future.result()

        # Create a Future that resolves to the state dict
        from concurrent.futures import ThreadPoolExecutor

        executor = ThreadPoolExecutor(max_workers=1)

        def get_state_dict():
            time.sleep(0.1)  # Simulate some processing time
            return self.test_state_dict

        future_state_dict = executor.submit(get_state_dict)

        # Create a temporary directory for the checkpoint
        with tempfile.TemporaryDirectory() as temp_dir:
            checkpoint_path = os.path.join(temp_dir, "test_checkpoint")

            # Write checkpoint with Future state dict
            write_future = checkpoint_process.write(future_state_dict, checkpoint_path)

            # Wait for completion
            write_future.result()

            # Verify checkpoint file was created
            expected_file = os.path.join(
                checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
            )
            self.assertTrue(os.path.exists(expected_file))

        executor.shutdown(wait=True)
        checkpoint_process.close()

    def test_checkpoint_write_with_kwargs(self) -> None:
        """Test checkpoint writing with additional kwargs."""
        checkpoint_process = self._create_checkpoint_process()

        # Wait for initialization
        checkpoint_process.process_creation_future.result()

        with tempfile.TemporaryDirectory() as temp_dir:
            checkpoint_path = os.path.join(temp_dir, "test_checkpoint")

            # Write checkpoint with kwargs
            future = checkpoint_process.write(
                self.test_state_dict,
                checkpoint_path,
                custom_arg="test_value",
                another_arg=42,
            )

            # Wait for completion
            future.result()

            # Verify checkpoint was created
            expected_file = os.path.join(
                checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
            )
            self.assertTrue(os.path.exists(expected_file))

        checkpoint_process.close()

    def test_subprocess_initialization_timeout(self) -> None:
        """Test subprocess initialization timeout."""

        # Create checkpoint process with a very short timeout by mocking the initialization
        checkpoint_process = self._create_checkpoint_process(
            subprocess_init_fn_override=timedout_subprocess_init_fn,
            subprocess_init_timeout_secs=1,
        )

        # This should timeout
        with self.assertRaises(TimeoutError) as cm:
            checkpoint_process.process_creation_future.result()

        self.assertIn("Timed out", str(cm.exception))

    def test_subprocess_initialization_failure(self) -> None:
        """Test subprocess initialization failure."""
        checkpoint_process = self._create_checkpoint_process(
            subprocess_init_fn_override=failing_subprocess_init_fn
        )

        # The subprocess should fail to initialize
        # We expect this to raise an exception when we try to use it
        with self.assertRaises(RuntimeError):
            checkpoint_process.process_creation_future.result()

    def test_graceful_termination(self) -> None:
        """Test graceful termination of subprocess."""
        checkpoint_process = self._create_checkpoint_process()

        checkpoint_process.process_creation_future.result()
        self.assertTrue(checkpoint_process.process.processes[0].is_alive())
        checkpoint_process.close()
        self.assertFalse(checkpoint_process.process.processes[0].is_alive())

    def test_forced_termination(self) -> None:
        """Test forced termination when graceful termination fails."""
        checkpoint_process = self._create_checkpoint_process()

        # Wait for initialization
        checkpoint_process.process_creation_future.result()

        # Mock the join method to simulate timeout
        def mock_join(timeout=None):
            # Acknowledge timeout parameter to avoid unused variable warning
            _ = timeout
            return False  # Simulate timeout

        checkpoint_process.process.join = mock_join

        # This should trigger forced termination
        checkpoint_process.close()

        # Process should still be terminated (killed)
        # Note: This test might be flaky depending on timing

    def test_communication_error_handling(self):
        """Test handling of communication errors."""
        checkpoint_process = self._create_checkpoint_process()

        # Wait for initialization
        checkpoint_process.process_creation_future.result()

        # Close the pipe to simulate communication failure
        checkpoint_process._parent_end.close()

        # Attempting to write should raise an error
        with self.assertRaises(RuntimeError) as cm:
            future = checkpoint_process.write(self.test_state_dict, "/tmp/test")
            future.result()

        self.assertIn("Child process terminated unexpectedly", str(cm.exception))

    def test_shared_memory_tensor_ipc(self):
        """Test that shared memory tensors are backed by the same memory across processes."""

        checkpoint_process = self._create_checkpoint_process(
            writer_init_fn_override=shared_tensor_verifier_init_fn,
        )

        checkpoint_process.process_creation_future.result()

        # Create tensors and put them in shared memory
        shared_tensor = torch.randn(100, 100)
        shared_tensor.share_memory_()

        shared_tensor_data_ptr = shared_tensor.data_ptr()

        regular_tensor = torch.randn(50, 50)
        # Don't put regular tensor in shared memory for comparison

        # Verify initial shared memory status
        self.assertTrue(
            shared_tensor.is_shared(), "Shared tensor should be in shared memory"
        )
        self.assertFalse(
            regular_tensor.is_shared(), "Regular tensor should not be in shared memory"
        )

        # Create state dict with mixed tensor types
        test_state_dict = {
            "shared_tensor": shared_tensor,
            "regular_tensor": regular_tensor,
        }

        # Write to subprocess - the SharedTensorVerifier will:
        # 1. Verify the tensor is still in shared memory
        # 2. Check the marker value (42.0) to confirm same memory
        # 3. Modify specific positions to prove same memory access
        future = checkpoint_process.write(test_state_dict, "")

        try:
            result = (
                future.result()
            )  # This will raise an exception if the subprocess assertions fail
            self.assertIsNone(result)  # SharedTensorVerifier returns None on success
        except Exception as e:
            self.fail(f"Subprocess assertions failed: {e}")

        # assert shared tensor is still in same shared memory
        self.assertEqual(
            shared_tensor_data_ptr,
            shared_tensor.data_ptr(),
            "Shared tensor should still be in same shared memory",
        )
        self.assertTrue(
            shared_tensor.is_shared(), "Shared tensor should still be in shared memory"
        )

        # CRITICAL TEST: Verify that modifications made by subprocess are visible in main process
        # This definitively proves that both processes access the same memory

        self.assertAlmostEqual(
            shared_tensor[0][0],
            42.0,
            places=6,
            msg=f"Expected subprocess signature 42.0, got {shared_tensor[0]}. "
            f"Shared memory not working - subprocess modifications not visible!",
        )

        checkpoint_process.close()


if __name__ == "__main__":
    run_tests()
216  test/distributed/checkpoint/_experimental/test_staging.py  (Normal file)
@@ -0,0 +1,216 @@
# Owner(s): ["oncall: distributed checkpointing"]

from concurrent.futures import Future

import torch
from torch.distributed.checkpoint._experimental.staging import (
    CheckpointStagerConfig,
    DefaultStager,
)
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase


class TestDefaultStager(TestCase):
    def setUp(self) -> None:
        # Create a test state dictionary with various data types
        self.state_dict = {
            "model": torch.nn.Linear(10, 5).state_dict(),
            "optimizer": {"param_groups": [{"lr": 0.01}]},
            "epoch": 5,
            "step": 1000,
            "tensor": torch.randn(3, 4),
            "nested": {"inner_tensor": torch.ones(2, 2), "inner_value": 42},
        }

    @requires_cuda
    def test_sync_staging(self) -> None:
        """Test synchronous staging."""
        options = CheckpointStagerConfig(use_async_staging=False)
        stager = DefaultStager(options)

        # Stage the state dict
        staged_dict = stager.stage(self.state_dict)

        # Verify that a state dict is returned (not a Future)
        self.assertIsInstance(staged_dict, dict)

        # Verify the staged state dictionary
        self.assertIn("model", staged_dict)
        self.assertIn("optimizer", staged_dict)
        self.assertEqual(staged_dict["epoch"], 5)
        self.assertEqual(staged_dict["step"], 1000)
        self.assertIn("tensor", staged_dict)
        self.assertIn("nested", staged_dict)

        # Clean up
        stager.close()

    @requires_cuda
    def test_async_staging(self) -> None:
        """Test asynchronous staging."""
        options = CheckpointStagerConfig(use_async_staging=True)
        stager = DefaultStager(options)

        # Stage the state dict
        result = stager.stage(self.state_dict)

        # Verify that a Future is returned
        self.assertIsInstance(result, Future)

        # Wait for the Future to complete
        staged_dict = result.result()

        # Verify the staged state dictionary
        self.assertIn("model", staged_dict)
        self.assertIn("optimizer", staged_dict)
        self.assertEqual(staged_dict["epoch"], 5)
        self.assertEqual(staged_dict["step"], 1000)

        # Clean up
        stager.close()

    def test_cuda_non_blocking_without_cuda(self) -> None:
        """Test that non-blocking copy fails when CUDA is not available."""
        if torch.cuda.is_available():
            self.skipTest("CUDA is available, cannot test CUDA unavailable scenario")

        options = CheckpointStagerConfig(use_cuda_non_blocking_copy=True)
        with self.assertRaises(AssertionError):
            DefaultStager(options)

    def test_different_option_combinations(self) -> None:
        """Test various combinations of staging options."""
        test_cases = [
            # All disabled
            CheckpointStagerConfig(
                use_pinned_memory=False,
                use_shared_memory=False,
                use_async_staging=False,
                use_cuda_non_blocking_copy=False,
            ),
            # Only pinned memory
            CheckpointStagerConfig(
                use_pinned_memory=True,
                use_shared_memory=False,
                use_async_staging=False,
                use_cuda_non_blocking_copy=False,
            ),
            # Only shared memory
            CheckpointStagerConfig(
                use_pinned_memory=False,
                use_shared_memory=True,
                use_async_staging=False,
                use_cuda_non_blocking_copy=False,
            ),
        ]

        if torch.cuda.is_available():
            # Only async staging
            test_cases.append(
                CheckpointStagerConfig(
                    use_pinned_memory=torch.cuda.is_available(),
                    use_shared_memory=False,
                    use_async_staging=True,
                    use_cuda_non_blocking_copy=False,
                )
            )
            # Only CUDA non-blocking copy
            test_cases.append(
                CheckpointStagerConfig(
                    use_pinned_memory=torch.cuda.is_available(),
                    use_shared_memory=False,
                    use_async_staging=False,
                    use_cuda_non_blocking_copy=torch.cuda.is_available(),
                )
            )

        for options in test_cases:
            with self.subTest(options=options):
                stager = DefaultStager(options)

                # Test staging works with these options
                if options.use_async_staging and torch.cuda.is_available():
                    result = stager.stage(self.state_dict)
                    self.assertIsInstance(result, Future)
                    staged_dict = result.result()
                else:
                    staged_dict = stager.stage(self.state_dict)

                self.assertIsInstance(staged_dict, dict)
                self.assertIn("model", staged_dict)

                stager.close()

    @requires_cuda
    def test_cuda_tensors_staging(self) -> None:
        """Test staging with CUDA tensors."""
        # Create state dict with CUDA tensors
        cuda_state_dict = {
            "cuda_tensor": torch.randn(3, 4).cuda(),
            "cpu_tensor": torch.randn(2, 3),
            "mixed_model": {
                "weight": torch.randn(5, 5).cuda(),
                "bias": torch.randn(5).cuda(),
            },
        }

        options = CheckpointStagerConfig(use_async_staging=False)
        stager = DefaultStager(options)

        staged_dict = stager.stage(cuda_state_dict)
        assert isinstance(staged_dict, dict)

        # Verify tensors are staged (should be moved to CPU)
        self.assertIn("cuda_tensor", staged_dict)
        self.assertIn("cpu_tensor", staged_dict)
        self.assertIn("mixed_model", staged_dict)

        stager.close()

    @requires_cuda
    def test_resource_cleanup(self) -> None:
        """Test that resources are properly cleaned up."""
        options = CheckpointStagerConfig(use_async_staging=False)
        stager = DefaultStager(options)

        # Verify initial state
        self.assertIsNotNone(stager._state_dict_stager)

        # Close and verify cleanup
        stager.close()

    def test_multiple_staging_operations(self) -> None:
        """Test multiple staging operations with the same stager."""
        options = CheckpointStagerConfig(
            use_async_staging=False,
            use_pinned_memory=torch.cuda.is_available(),
            use_shared_memory=False,
            use_cuda_non_blocking_copy=torch.cuda.is_available(),
        )
        stager = DefaultStager(options)

        # Stage multiple different state dicts
        state_dicts = [
            {"model1": torch.nn.Linear(5, 3).state_dict()},
            {"model2": torch.nn.Conv2d(3, 16, 3).state_dict()},
            {"optimizer": {"lr": 0.001, "momentum": 0.9}},
        ]

        staged_results = []
        for state_dict in state_dicts:
            staged_dict = stager.stage(state_dict)
            staged_results.append(staged_dict)

        # Verify all staging operations succeeded
        self.assertEqual(len(staged_results), 3)
        for i, result in enumerate(staged_results):
            self.assertIsInstance(result, dict)
            # Verify the result contains the expected keys
            for key in state_dicts[i].keys():
                self.assertIn(key, result)

        stager.close()


if __name__ == "__main__":
    run_tests()
@@ -19,9 +19,14 @@ from .barriers import (
    create_barrier_from_config,
    TCPStoreBarrier,
)
from .builder import make_async_checkpointer, make_sync_checkpointer
from .checkpoint_reader import CheckpointReader
from .checkpoint_writer import CheckpointWriter, CheckpointWriterConfig, WriterHook
from .checkpointer import AsyncCheckpointer, Checkpointer, SyncCheckpointer
from .config import CheckpointerConfig
from .staging import CheckpointStager, CheckpointStagerConfig, DefaultStager
from .types import RankInfo, STATE_DICT
from .utils import wrap_future


__all__ = [

@@ -31,8 +36,18 @@ __all__ = [
    "CheckpointWriter",
    "CheckpointWriterConfig",
    "WriterHook",
    "Checkpointer",
    "SyncCheckpointer",
    "AsyncCheckpointer",
    "CheckpointerConfig",
    "BarrierConfig",
    "create_barrier_from_config",
    "CheckpointStager",
    "CheckpointStagerConfig",
    "DefaultStager",
    "RankInfo",
    "STATE_DICT",
    "wrap_future",
    "make_sync_checkpointer",
    "make_async_checkpointer",
]
173  torch/distributed/checkpoint/_experimental/builder.py  (Normal file)
@@ -0,0 +1,173 @@
"""
Factory functions for creating checkpointer instances with sensible defaults.

This module provides high-level factory functions that simplify the creation
of checkpointer instances by automatically handling component initialization
and configuration with reasonable defaults.
"""

from typing import Any, Callable, Optional

import torch.distributed as dist

from .barriers import create_barrier_from_config
from .checkpoint_process import CheckpointProcess
from .checkpoint_reader import CheckpointReader
from .checkpoint_writer import CheckpointWriter, CheckpointWriterConfig, WriterHook
from .checkpointer import AsyncCheckpointer, SyncCheckpointer
from .config import CheckpointerConfig
from .staging import DefaultStager
from .types import RankInfo


def _get_default_rank_info() -> RankInfo:
    """
    Get default rank information from the current distributed environment.

    Returns:
        RankInfo: Rank information from the default process group if initialized,
            otherwise single-rank fallback.
    """
    if dist.is_initialized():
        return RankInfo(
            global_world_size=dist.get_world_size(),
            global_rank=dist.get_rank(),
        )
    else:
        # Single-rank fallback
        return RankInfo(global_world_size=1, global_rank=0)


def default_subprocess_init_fn(*_: Any) -> None:
    """Default subprocess initialization function (no-op)."""


def default_writer_init_fn(rank_info: RankInfo) -> CheckpointWriter:
    """Default checkpoint writer initialization function."""
    return CheckpointWriter(
        config=CheckpointWriterConfig(),
        rank_info=rank_info,
    )


def make_sync_checkpointer(
    config: CheckpointerConfig = CheckpointerConfig(),
    rank_info: Optional[RankInfo] = None,
    commit_hook: Optional[WriterHook] = None,
) -> SyncCheckpointer:
    """
    Factory function to create a SyncCheckpointer instance with sensible defaults.

    This function creates a synchronous checkpointer with default components, automatically
    detecting rank information from the default process group if available, and using the
    provided component configurations.

    Args:
        config: CheckpointerConfig containing component-specific configurations
            (writer_config, staging_config, process_config). Defaults to CheckpointerConfig().
        rank_info: RankInfo for distributed training. Defaults to auto-detection from
            the default PyTorch distributed process group if initialized, otherwise
            falls back to single-rank (world_size=1, rank=0).
        commit_hook: Optional hook for custom actions before and after checkpoint commits.

    Returns:
        SyncCheckpointer: A configured synchronous checkpointer instance.

    Examples:
        # Simplest usage - auto-detect rank, default config
        checkpointer = make_sync_checkpointer()

        # Explicit rank configuration
        checkpointer = make_sync_checkpointer(
            rank_info=RankInfo(global_world_size=4, global_rank=0)
        )

        # Disable barrier
        from .barriers import BarrierConfig
        config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
        checkpointer = make_sync_checkpointer(config=config)
    """
    if rank_info is None:
        rank_info = _get_default_rank_info()

    reader = CheckpointReader(
        rank_info=rank_info,
    )

    barrier = create_barrier_from_config(config.barrier_config)

    writer = CheckpointWriter(
        config=config.writer_config,
        rank_info=rank_info,
        barrier=barrier,
        commit_hook=commit_hook,
    )

    return SyncCheckpointer(
        writer=writer,
        reader=reader,
    )


def make_async_checkpointer(
    config: CheckpointerConfig = CheckpointerConfig(),
    rank_info: Optional[RankInfo] = None,
    subprocess_init_fn: Callable[..., None] = default_subprocess_init_fn,
    subprocess_init_args: tuple[Any, ...] = (),
    checkpoint_writer_init_fn: Callable[..., CheckpointWriter] = default_writer_init_fn,
    checkpoint_writer_init_args: Optional[dict[str, Any]] = None,
) -> AsyncCheckpointer:
    """
    Factory function to create an AsyncCheckpointer instance with sensible defaults.

    This function creates an asynchronous checkpointer using the provided configuration,
    automatically detecting rank information if not provided.

    Args:
        config: CheckpointerConfig containing component-specific configurations.
        rank_info: RankInfo for distributed training. Defaults to auto-detection.
        subprocess_init_fn: Function to initialize the subprocess. Defaults to no-op.
        subprocess_init_args: Arguments to pass to subprocess_init_fn.
        checkpoint_writer_init_fn: Function to create CheckpointWriter instance.
        checkpoint_writer_init_args: Arguments to pass to checkpoint_writer_init_fn.

    Returns:
        AsyncCheckpointer: A configured asynchronous checkpointer instance.

    Examples:
        # Create with default config
        checkpointer = make_async_checkpointer()

        # Create with custom init functions
        checkpointer = make_async_checkpointer(
            subprocess_init_fn=my_subprocess_init_fn,
            checkpoint_writer_init_fn=my_writer_init_fn
        )
    """
    if rank_info is None:
        rank_info = _get_default_rank_info()

    reader = CheckpointReader(
        rank_info=rank_info,
    )

    checkpoint_stager = DefaultStager(
        config=config.staging_config,
    )

    checkpoint_writer_init_args = checkpoint_writer_init_args or {}

    checkpoint_process = CheckpointProcess(
        rank_info=rank_info,
        config=config.process_config,
        subprocess_init_fn=subprocess_init_fn,
        subprocess_init_args=subprocess_init_args,
        checkpoint_writer_init_fn=checkpoint_writer_init_fn,
        checkpoint_writer_init_args=checkpoint_writer_init_args,
    )

    return AsyncCheckpointer(
        checkpoint_stager=checkpoint_stager,
        checkpoint_process=checkpoint_process,
        reader=reader,
    )
361  torch/distributed/checkpoint/_experimental/checkpoint_process.py  (Normal file)
@@ -0,0 +1,361 @@
import logging
import os
import traceback
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from multiprocessing.connection import Connection
from typing import Any, Callable, Optional, Union

import torch.multiprocessing as mp
from torch.multiprocessing.spawn import ProcessExitedException

from .checkpoint_writer import CheckpointWriter
from .types import RankInfo, STATE_DICT


logger = logging.getLogger(__name__)


@dataclass
class CheckpointProcessConfig:
    """
    Configuration options for the CheckpointProcess.

    This class provides configuration options for the checkpoint process,
    including initialization functions, timeouts, and writer configuration.

    Attributes:
        subprocess_init_timeout_secs: Maximum time in seconds to wait for subprocess initialization.
        subprocess_shutdown_timeout_secs: Maximum time in seconds to wait for subprocess shutdown.
    """

    subprocess_init_timeout_secs: int = 30
    subprocess_shutdown_timeout_secs: int = 60


class RequestType(Enum):
    PING = "ping"
    WRITE_CHECKPOINT = "write_checkpoint"
    TERMINATE_PROCESS = "exit"


@dataclass
class WorkerRequest:
    """
    A dataclass for storing the command to be sent to the worker process.
    Note: This relies on pickling to send the command to the worker process. Handle
    backward compatibility accordingly.
    """

    request_type: RequestType
    payload: dict[str, Any]


@dataclass
class WorkerResponse:
    request_type: RequestType
    success: bool
    error_msg: Optional[str] = None
    payload: Optional[dict[str, Any]] = None


class CheckpointProcess:
    """
    A checkpoint writer that writes checkpoints to a remote process.
    """

    def __init__(
        self,
        rank_info: RankInfo,
        config: CheckpointProcessConfig,
        subprocess_init_fn: Callable[[Any], None],
        subprocess_init_args: tuple[Any, ...],
        checkpoint_writer_init_fn: Callable[..., CheckpointWriter],
        checkpoint_writer_init_args: dict[str, Any],
    ):
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._rank_info = rank_info
        self._config = config
        self._subprocess_init_fn = subprocess_init_fn
        self._subprocess_init_args = subprocess_init_args
        self._checkpoint_writer_init_fn = checkpoint_writer_init_fn
        self._checkpoint_writer_init_args = checkpoint_writer_init_args
        self.process = None
        self._parent_end: Optional[Connection] = None
        self._child_end: Optional[Connection] = None

        self.process_creation_future = self._executor.submit(
            self._create_subprocess,
            config,
        )

    def _create_subprocess(
        self,
        config: CheckpointProcessConfig,
    ) -> None:
        logger.info(
            "Creating checkpoint subprocess for rank %d", self._rank_info.global_rank
        )

        spawn_context = mp.get_context("spawn")
        self._parent_end, child_end = spawn_context.Pipe()

        # Known workaround for https://github.com/pytorch/pytorch/issues/37377
        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"

        logger.debug("Spawning subprocess for rank_info=%s", self._rank_info)
        self.process = mp.spawn(
            fn=CheckpointProcess._subprocess,
            args=(
                self._rank_info,
                child_end,
                self._subprocess_init_fn,
                self._subprocess_init_args,
                self._checkpoint_writer_init_fn,
                self._checkpoint_writer_init_args,
            ),
            nprocs=1,
            join=False,
            daemon=True,
        )

        # close the child end of the pipe so recv on it will fail
        # fast when the child process is terminated unexpectedly.
        child_end.close()
        self._send(
            request_type=RequestType.PING,
            payload={},
        )

        logger.debug(
            "Waiting for checkpoint subprocess to initialize (timeout: %ds)",
            config.subprocess_init_timeout_secs,
        )

        # wait for the timeout or a response from subprocess
        assert self._parent_end is not None, "Parent end of pipe should be initialized"
        if not self._parent_end.poll(timeout=config.subprocess_init_timeout_secs):
            msg = f"Timed out after {config.subprocess_init_timeout_secs}s waiting for checkpoint subprocess to initialize"
            logger.error(msg)
            raise TimeoutError(msg)

        self._recv()
        logger.info("Checkpoint subprocess initialized successfully")

    @staticmethod
    def _subprocess(
        sub_rank: int,
        rank_info: RankInfo,
        parent_pipe: Connection,
        subprocess_init_fn: Callable[[Any], None],
        subprocess_init_args: tuple[Any, ...],
        checkpoint_writer_init_fn: Callable[..., CheckpointWriter],
        checkpoint_writer_init_args: dict[str, Any],
    ) -> None:
        logger.debug(
            "Checkpoint subprocess started for rank %d/%d (PID: %d)",
            rank_info.global_rank,
            rank_info.global_world_size,
            os.getpid(),
        )

        assert sub_rank == 0, "We need only one checkpointer per parent training"
        request = WorkerRequest(request_type=RequestType.PING, payload={})

        try:
            # Calling initialize callback, so we can perform app-specific initialization of the subprocess.
            subprocess_init_fn(*subprocess_init_args)

            # Initialize checkpoint writer - automatically include rank_info in init_args
            writer_init_args = dict(checkpoint_writer_init_args)
            if "rank_info" not in writer_init_args:
                writer_init_args["rank_info"] = rank_info
            checkpoint_writer = checkpoint_writer_init_fn(**writer_init_args)

            while True:
                request = parent_pipe.recv()

                if request.request_type == RequestType.PING:
                    parent_pipe.send(
                        WorkerResponse(request_type=RequestType.PING, success=True)
                    )
                elif request.request_type == RequestType.WRITE_CHECKPOINT:
                    path = request.payload["path"]
                    logger.info("Writing checkpoint to %s", path)

                    checkpoint_writer.write(
                        state_dict=request.payload["state_dict"],
                        path=path,
                        **request.payload["kwargs"],
                    )

                    logger.info("Checkpoint written successfully to %s", path)
                    parent_pipe.send(
                        WorkerResponse(RequestType.WRITE_CHECKPOINT, success=True)
                    )
                elif request.request_type == RequestType.TERMINATE_PROCESS:
                    logger.debug("Received termination request.")
                    parent_pipe.send(
                        WorkerResponse(RequestType.TERMINATE_PROCESS, success=True)
                    )
                    logger.info("Subprocess terminated gracefully")
                    break
                else:
                    error_msg = f"Unknown request type: {request.request_type}"
                    logger.error(error_msg)
                    raise ValueError(error_msg)

        except Exception as e:
            error_text = traceback.format_exc()
            logger.error(
                "Exception in subprocess (%s): %s", type(e).__name__, error_text
            )

            # Communicating exception via the queue to the main process
            parent_pipe.send(
                WorkerResponse(
                    request_type=request.request_type,
                    success=False,
                    error_msg=error_text,
                )
            )
            parent_pipe.close()
            logger.error("Subprocess terminated due to exception: %s", e)

    def _send(self, request_type: RequestType, payload: dict[str, Any]) -> None:
        try:
            assert self._parent_end is not None, (
                "Parent end of pipe should be initialized"
            )
            self._parent_end.send(
                WorkerRequest(
                    request_type=request_type,
                    payload=payload,
                )
            )
        except OSError as e:
            error_msg = "Child process terminated unexpectedly"
            logger.error(
                "Communication failed during %s request: %s", request_type.value, e
            )
            raise RuntimeError(error_msg) from e

    def _recv(self) -> Optional[dict[str, Any]]:
        try:
            assert self._parent_end is not None, (
                "Parent end of pipe should be initialized"
            )
            response = self._parent_end.recv()
            if response.success is False:
                error_msg = (
                    f"Unexpected response from worker process: {response.error_msg}"
                )
                logger.error(error_msg)
                raise RuntimeError(error_msg)
            return response.payload
        except (EOFError, BrokenPipeError, ConnectionResetError) as e:
            error_msg = f"Child process terminated unexpectedly: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def write(
        self,
        state_dict: Union[STATE_DICT, Future[STATE_DICT]],
        path: str,
        **kwargs: Any,
    ) -> Optional[Future[None]]:
        logger.debug("Waiting for subprocess initialization to complete")

        # wait until the process is started
        self.process_creation_future.result()

        return self._executor.submit(
            self._write,
            state_dict,
            path,
            **kwargs,
        )

    def _write(
        self,
        state_dict: Union[STATE_DICT, Future[STATE_DICT]],
        path: str,
        **kwargs: Any,
    ) -> None:
        logger.debug("Starting checkpoint write to %s", path)

        # wait for staging state_dict to be available
        if isinstance(state_dict, Future):
            logger.debug("Waiting for state_dict Future to resolve")
            sd = state_dict.result()
        else:
            sd = state_dict

        # Log state_dict info only if debug logging is enabled (performance-conscious)
        if logger.isEnabledFor(logging.DEBUG):
            if hasattr(sd, "keys"):
                logger.debug("State_dict contains %d keys", len(sd.keys()))

        self._send(
            request_type=RequestType.WRITE_CHECKPOINT,
            payload={
                "state_dict": sd,
                "path": path,
                "kwargs": kwargs,
            },
        )

        logger.debug("Waiting for write completion response")
        # wait for response
        self._recv()
        logger.debug("Checkpoint write to %s completed successfully", path)

    def close(self) -> None:
        logger.debug(
            "Closing CheckpointProcess for rank %d", self._rank_info.global_rank
        )
        self._executor.shutdown(wait=True, cancel_futures=True)

        if self.process and self.process.processes[0].is_alive():
            subprocess_pid = self.process.processes[0].pid
            # send graceful termination to sub process
            try:
                self._parent_end.send(
                    WorkerRequest(
                        request_type=RequestType.TERMINATE_PROCESS,
                        payload={},
                    )
                )
            except BrokenPipeError:
                logger.warning(
                    "BrokenPipeError when sending termination request - subprocess (PID: %d) may have already terminated",
                    subprocess_pid,
                )
                # subprocess terminated unexpectedly and below code will raise a
                # ProcessExitedException.

            logger.debug(
                "Waiting for subprocess to terminate gracefully (timeout: %ds)",
                self._config.subprocess_shutdown_timeout_secs,
            )

            try:
                if not self.process.join(
                    timeout=self._config.subprocess_shutdown_timeout_secs
                ):
                    # graceful shutdown failed, kill the process.
                    logger.warning(
                        "Subprocess (PID: %d) did not terminate gracefully within %ds, killing it",
                        subprocess_pid,
                        self._config.subprocess_shutdown_timeout_secs,
                    )
                    self.process.processes[0].kill()
                    logger.info("Subprocess killed forcefully")
            except ProcessExitedException as e:
                logger.error(
                    "ProcessExitedException during subprocess termination: %s", e
                )
                raise

        logger.debug("CheckpointProcess closed successfully")
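A short sketch of how the init-function hooks above are meant to be supplied to the builder API. This is patterned after ckpt_writer_init_fn in the tests and the docstring placeholders my_subprocess_init_fn / my_writer_init_fn in builder.py; the job-name argument and the print call are illustrative only.

from typing import Any

from torch.distributed.checkpoint._experimental.builder import make_async_checkpointer
from torch.distributed.checkpoint._experimental.checkpoint_writer import (
    CheckpointWriter,
    CheckpointWriterConfig,
)


def my_subprocess_init_fn(job_name: str) -> None:
    # Runs once inside the spawned checkpoint process; must be defined at module
    # level so it can be pickled and sent to the subprocess.
    print(f"checkpoint subprocess started for {job_name}")


def my_writer_init_fn(**kwargs: Any) -> CheckpointWriter:
    # Called in the subprocess to build the writer; rank_info is injected
    # automatically by CheckpointProcess._subprocess if not supplied.
    return CheckpointWriter(
        config=kwargs.get("config") or CheckpointWriterConfig(),
        rank_info=kwargs.get("rank_info"),
    )


checkpointer = make_async_checkpointer(
    subprocess_init_fn=my_subprocess_init_fn,
    subprocess_init_args=("my-job",),
    checkpoint_writer_init_fn=my_writer_init_fn,
)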
@ -3,9 +3,12 @@ import logging
|
|||
from concurrent.futures import Future
|
||||
from typing import Any, Optional, TypeVar
|
||||
|
||||
from .checkpoint_process import CheckpointProcess
|
||||
from .checkpoint_reader import CheckpointReader
|
||||
from .checkpoint_writer import CheckpointWriter
|
||||
from .staging import CheckpointStager
|
||||
from .types import STATE_DICT
|
||||
from .utils import wrap_future
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -190,3 +193,149 @@ class SyncCheckpointer(Checkpointer):
|
|||
"""
|
||||
self._writer.close()
|
||||
logger.info("SyncCheckpointer closed")
|
||||
|
||||
|
||||
class AsyncCheckpointer(Checkpointer):
|
||||
"""
|
||||
Asynchronous implementation of Checkpointer.
|
||||
|
||||
This class coordinates the writing and loading of model state dictionaries to and from storage
|
||||
using asynchronous operations for saving. It provides efficient async checkpoint operations
|
||||
with staging and background writing capabilities.
|
    Attributes:
        _reader: CheckpointReader for reading state dictionaries from storage.
        _checkpoint_stager: Stager for async operations.
        _checkpoint_process: Process for async operations.
        _write_future: Future representing the ongoing async write operation.

    Example:
        checkpointer = AsyncCheckpointer(
            reader=reader,
            checkpoint_stager=stager,
            checkpoint_process=process
        )
        stage_future, write_future = checkpointer.save(state_dict, path)
        # ... do other work ...
        write_future.result()  # Wait for completion
    """

    def __init__(
        self,
        checkpoint_stager: CheckpointStager,
        checkpoint_process: CheckpointProcess,
        reader: CheckpointReader,
    ):
        """
        Initialize an asynchronous checkpointer.

        Args:
            checkpoint_stager: Stager for async operations.
            checkpoint_process: Process for async operations.
            reader: CheckpointReader for reading checkpoints from storage.
        """
        self._reader = reader
        self._checkpoint_stager = checkpoint_stager
        self._checkpoint_process = checkpoint_process
        self._write_future: Optional[Future[Any]] = None

    def save(
        self,
        state_dict: STATE_DICT,
        path: str,
        **kwargs: Any,
    ) -> Optional[tuple[Future, Future]]:
        """
        Save a state dictionary to storage asynchronously.

        Args:
            state_dict: The state dictionary to save.
            path: The path where the checkpoint should be saved.
            **kwargs: Additional keyword arguments to pass to the stager and writer.

        Returns:
            A tuple of (stage_future, write_future) representing the staging and writing operations.

        Example:
            stage_future, write_future = checkpointer.save(state_dict, "/path/to/checkpoint")
            # ... do other work ...
            write_future.result()  # Wait for completion
        """
        logger.info(
            "Initiating checkpoint save to %s. Will wait for prev checkpoints to complete.",
            path,
        )
        # Wait for previous checkpoint ops to finish and verify they are successful
        if self._write_future is not None:
            self._write_future.result()

        logger.debug("Starting state dictionary staging")
        staging_result = self._checkpoint_stager.stage(
            state_dict=state_dict,
            **kwargs,
        )

        logger.debug("Starting checkpoint write to %s", path)
        self._write_future = self._checkpoint_process.write(
            staging_result, path, **kwargs
        )
        logger.info("Checkpoint save to %s initiated", path)

        # Return futures for the staging and writing operations
        if self._write_future is not None:
            return wrap_future(staging_result), self._write_future
        else:
            # This should not happen since we just assigned _write_future above
            raise RuntimeError("Write future is unexpectedly None")

    def load(
        self,
        path: str,
        state_dict: Optional[STATE_DICT] = None,
        *,
        default_map_location: Any = None,
        strict: bool = False,
        **kwargs: Any,
    ) -> STATE_DICT:
        """
        Load a state dictionary from storage.

        Loading is always performed synchronously, even in AsyncCheckpointer.

        Args:
            path: The path from which to load the checkpoint.
            state_dict: Optional state dictionary to update with loaded values.
                If provided, only keys in this dictionary will be loaded.
            default_map_location: Device mapping function or device name for relocating tensors.
            strict: If True, raises an error when there are missing keys in the checkpoint.
            **kwargs: Additional keyword arguments to pass to the reader.

        Returns:
            The loaded state dictionary.

        Raises:
            RuntimeError: If strict=True and there are missing keys in the checkpoint.
            FileNotFoundError: If the checkpoint file is not found.
        """
        logger.info("Loading checkpoint from %s", path)

        loaded_state_dict, missing_keys = self._reader.read(
            path=path,
            state_dict=state_dict,
            map_location=default_map_location,
            **kwargs,
        )
        if strict and missing_keys is not None and missing_keys != []:
            raise RuntimeError(f"Checkpoint at {path} is missing keys: {missing_keys}")
        return loaded_state_dict

    def close(self) -> None:
        """
        Close the checkpointer and release any resources.

        This method should be called when the checkpointer is no longer needed to ensure
        proper cleanup of async resources.
        """
        self._checkpoint_stager.close()
        self._checkpoint_process.close()
        logger.info("AsyncCheckpointer closed")
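A minimal usage sketch of the async flow above (illustrative only, not from the diff). It assumes make_async_checkpointer accepts the same config/rank_info keywords as its sync counterpart, and that the default staging options, which expect CUDA, are usable on the host.

import torch
from torch.distributed.checkpoint._experimental.builder import make_async_checkpointer
from torch.distributed.checkpoint._experimental.types import RankInfo

# Hypothetical single-rank setup; a real job passes its own rank/world size.
checkpointer = make_async_checkpointer(
    rank_info=RankInfo(global_world_size=1, global_rank=0)
)

state_dict = {"model": torch.nn.Linear(10, 5).state_dict(), "step": 100}
stage_future, write_future = checkpointer.save(state_dict, "/tmp/ckpt/step_100")
# ... keep training while staging and the out-of-process write proceed ...
write_future.result()  # block only when the checkpoint must be durable

restored = checkpointer.load("/tmp/ckpt/step_100")  # loads stay synchronous
checkpointer.close()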
torch/distributed/checkpoint/_experimental/config.py (new file, 44 lines)

@@ -0,0 +1,44 @@
"""
Configuration classes for checkpointer construction.

This module provides configuration dataclasses that consolidate all
configuration options needed to construct checkpointers.
"""

from dataclasses import dataclass, field

from .barriers import BarrierConfig
from .checkpoint_process import CheckpointProcessConfig
from .checkpoint_writer import CheckpointWriterConfig
from .staging import CheckpointStagerConfig


@dataclass
class CheckpointerConfig:
    """
    Configuration class for checkpointer construction.

    This class consolidates the core component configuration options needed to construct
    a checkpointer, providing a clean separation of concerns where each component
    manages its own configuration.

    Attributes:
        writer_config: Configuration options for the checkpoint writer component.
        barrier_config: Configuration for barrier construction and arguments.
        staging_config: Configuration options for the async staging component.
        process_config: Configuration options for the async checkpoint process component.
    """

    writer_config: CheckpointWriterConfig = field(
        default_factory=CheckpointWriterConfig
    )
    barrier_config: BarrierConfig = field(default_factory=BarrierConfig)

    # Below configs are used for async checkpointing
    staging_config: CheckpointStagerConfig = field(
        default_factory=CheckpointStagerConfig
    )
    process_config: CheckpointProcessConfig = field(
        default_factory=CheckpointProcessConfig
    )
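For illustration (editor's sketch, not from the diff): a config that keeps the default writer and process options but skips the barrier and stages on the caller's thread. BarrierConfig(barrier_type=None) mirrors the no-barrier setup exercised by the new tests.

from torch.distributed.checkpoint._experimental.barriers import BarrierConfig
from torch.distributed.checkpoint._experimental.config import CheckpointerConfig
from torch.distributed.checkpoint._experimental.staging import CheckpointStagerConfig

config = CheckpointerConfig(
    barrier_config=BarrierConfig(barrier_type=None),  # no cross-rank barrier
    staging_config=CheckpointStagerConfig(use_async_staging=False),  # stage inline
)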
torch/distributed/checkpoint/_experimental/staging.py (new file, 216 lines)

@@ -0,0 +1,216 @@
"""
|
||||
Experimental staging module for PyTorch Distributed Checkpointing.
|
||||
|
||||
This module provides advanced staging capabilities for checkpoints including:
|
||||
- Asynchronous staging using ThreadPoolExecutor
|
||||
- Pinned memory allocation for faster CPU-GPU transfers
|
||||
- Shared memory support for multi-process scenarios
|
||||
- Non-blocking CUDA operations with stream synchronization
|
||||
- Caching of frequently used storages for efficient memory management
|
||||
- Automatic resource cleanup and memory management
|
||||
|
||||
Classes:
|
||||
CheckpointStager: Abstract base class defining the staging interface
|
||||
StagingOptions: Configuration dataclass for staging behavior
|
||||
DefaultStager: Default implementation with comprehensive staging features
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from logging import getLogger
|
||||
from typing import Any, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed.checkpoint._state_dict_stager import StateDictStager
|
||||
|
||||
from .types import STATE_DICT
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class CheckpointStager(abc.ABC):
|
||||
"""
|
||||
Abstract base class for checkpoint staging implementations.
|
||||
|
||||
CheckpointStager defines the interface that all staging implementations
|
||||
must follow. Staging is the process of offloading state dictionaries
|
||||
for async checkpointing.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def stage(
|
||||
self,
|
||||
state_dict: STATE_DICT,
|
||||
**kwargs: Any,
|
||||
) -> Union[STATE_DICT, Future[STATE_DICT]]:
|
||||
"""
|
||||
Stage a state dictionary for checkpointing.
|
||||
|
||||
Args:
|
||||
state_dict: The state dictionary to stage
|
||||
**kwargs: Additional staging parameters
|
||||
|
||||
Returns:
|
||||
Either a staged state dictionary (synchronous) or a Future
|
||||
that will resolve to the staged state dictionary (asynchronous)
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Clean up all resources used by the stager.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CheckpointStagerConfig:
|
||||
"""
|
||||
Configuration options for checkpoint staging behavior.
|
||||
|
||||
Attributes:
|
||||
use_pinned_memory (bool): Enable pinned memory allocation for faster
|
||||
CPU-GPU transfers. Requires CUDA to be available. Default: True
|
||||
use_shared_memory (bool): Enable shared memory for multi-process
|
||||
scenarios. Useful when multiple processes need access to the
|
||||
same staged data. Default: True
|
||||
use_async_staging (bool): Enable asynchronous staging using a
|
||||
background thread pool. Allows overlapping computation with
|
||||
staging operations. Requires CUDA. Default: True
|
||||
use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory
|
||||
copies with stream synchronization. Improves performance by
|
||||
allowing CPU work to continue during GPU transfers. Default: True
|
||||
|
||||
Note:
|
||||
CUDA-dependent features will raise exception if CUDA is not available.
|
||||
"""
|
||||
|
||||
use_pinned_memory: bool = True
|
||||
use_shared_memory: bool = True
|
||||
use_async_staging: bool = True
|
||||
use_cuda_non_blocking_copy: bool = True
|
||||
|
||||
|
||||
class DefaultStager(CheckpointStager):
|
||||
"""
|
||||
DefaultStager provides a full-featured staging implementation that combines
|
||||
multiple optimization techniques for efficient checkpoint preparation.
|
||||
|
||||
The staging process works as follows:
|
||||
1. State dictionary is submitted for staging (sync or async)
|
||||
2. Tensors are copied from GPU to optimized CPU storage
|
||||
3. CUDA operations are synchronized if non-blocking copies are used
|
||||
4. Staged state dictionary is returned or made available via Future
|
||||
|
||||
NOTE: state_dict should be deep-copyable object as staging will create a
|
||||
copy of it.
|
||||
|
||||
Usage Patterns:
|
||||
# Synchronous staging
|
||||
stager = DefaultStager(CheckpointStagerConfig(use_async_staging=False))
|
||||
staged_dict = stager.stage(state_dict)
|
||||
stager.close()
|
||||
|
||||
# Asynchronous staging
|
||||
stager = DefaultStager(CheckpointStagerConfig(use_async_staging=True))
|
||||
future = stager.stage(state_dict)
|
||||
# ... do other work ...
|
||||
staged_dict = future.result()
|
||||
stager.close()
|
||||
|
||||
# Context manager pattern (recommended)
|
||||
with DefaultStager(config) as stager:
|
||||
result = stager.stage(state_dict)
|
||||
# Automatic cleanup on exit
|
||||
|
||||
Performance Considerations:
|
||||
- Async staging provides best performance when model computation
|
||||
can overlap with staging operations
|
||||
- Pinned memory improves CPU-GPU transfer speeds but uses more memory
|
||||
- Shared memory allows efficient IPC to checkpoint process
|
||||
- Non-blocking copies reduce GPU idle time during memory transfers
|
||||
|
||||
Thread Safety:
|
||||
DefaultStager is not thread-safe. Each thread should use its own
|
||||
instance, or external synchronization should be provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CheckpointStagerConfig = CheckpointStagerConfig(),
|
||||
):
|
||||
self._config = config
|
||||
self._state_dict_stager = StateDictStager(
|
||||
pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory
|
||||
)
|
||||
self._staging_executor = None
|
||||
self._staging_stream = None
|
||||
|
||||
if self._config.use_async_staging:
|
||||
self._staging_executor = ThreadPoolExecutor(max_workers=1)
|
||||
if torch.cuda.is_available():
|
||||
# Note: stream needs to be initialized on the main thread after default cuda
|
||||
# stream is setup/used to avoid the risk of accidentally reusing the main
|
||||
# compute stream or in other cases kernels actually launching from the
|
||||
# main thread.
|
||||
self._staging_stream = torch.cuda.Stream()
|
||||
|
||||
if self._config.use_cuda_non_blocking_copy:
|
||||
assert torch.cuda.is_available(), "Non-blocking copy requires CUDA"
|
||||
|
||||
def stage(
|
||||
self,
|
||||
state_dict: STATE_DICT,
|
||||
**kwargs: Any,
|
||||
) -> Union[STATE_DICT, Future[STATE_DICT]]:
|
||||
if self._config.use_async_staging:
|
||||
assert self._staging_executor is not None, (
|
||||
"Staging executor should be initialized for async staging"
|
||||
)
|
||||
return self._staging_executor.submit(
|
||||
self._stage,
|
||||
state_dict,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self._stage(state_dict, **kwargs)
|
||||
|
||||
def _stage(self, state_dict: STATE_DICT, **kwargs: Any) -> STATE_DICT:
|
||||
state_dict = self._state_dict_stager.stage(
|
||||
state_dict, non_blocking=self._config.use_cuda_non_blocking_copy, **kwargs
|
||||
)
|
||||
|
||||
if self._config.use_cuda_non_blocking_copy:
|
||||
assert self._staging_stream or not self._config.use_async_staging, (
|
||||
"Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized."
|
||||
)
|
||||
|
||||
# waits for the enqued copy operations to finish.
|
||||
self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize()
|
||||
|
||||
return state_dict
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor
|
||||
used for async staging operations and cleans up the underlying StateDictStager's
|
||||
cached storages. Should be called when the stager is no longer needed to prevent
|
||||
resource leaks, especially in long-running applications. After calling close(),
|
||||
the stager should not be used for further staging operations.
|
||||
|
||||
state_dict should be deep-copyable object.
|
||||
|
||||
Example:
|
||||
stager = DefaultStager(CheckpointStagerConfig(use_async_staging=True))
|
||||
# ... do staging operations ...
|
||||
stager.close() # Clean up all resources
|
||||
"""
|
||||
if self._staging_executor:
|
||||
self._staging_executor.shutdown(wait=True)
|
||||
|
||||
self._state_dict_stager.close()
|
||||
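The staging defaults above assume a CUDA machine. An illustrative CPU-only setup (editor's sketch, not from the diff) disables every CUDA-dependent option and uses the synchronous path, which returns the staged dict directly.

import torch
from torch.distributed.checkpoint._experimental.staging import (
    CheckpointStagerConfig,
    DefaultStager,
)

# All CUDA-dependent options off, so this runs on a CPU-only host.
cpu_only = CheckpointStagerConfig(
    use_pinned_memory=False,
    use_shared_memory=True,
    use_async_staging=False,
    use_cuda_non_blocking_copy=False,
)

stager = DefaultStager(cpu_only)
staged = stager.stage({"weights": torch.randn(4, 4)})  # sync path returns the staged dict
stager.close()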
torch/distributed/checkpoint/_experimental/utils.py (new file, 42 lines)

@@ -0,0 +1,42 @@
"""
|
||||
Utility functions for the experimental checkpoint module.
|
||||
|
||||
This module contains helper functions and utilities used across the experimental
|
||||
checkpoint functionality.
|
||||
"""
|
||||
|
||||
from concurrent.futures import Future
|
||||
from typing import Any
|
||||
|
||||
|
||||
def wrap_future(original_result: Any) -> Future[None]:
|
||||
"""
|
||||
Wraps a result (Future or not) to return a Future with None result.
|
||||
|
||||
If the input is a Future, returns a new Future that completes with None when
|
||||
the original Future completes successfully, or propagates any exception.
|
||||
If the input is not a Future, returns a completed Future with None result.
|
||||
|
||||
Args:
|
||||
original_result: The result to wrap (Future or any other value).
|
||||
|
||||
Returns:
|
||||
A Future that completes with None on success or propagates exceptions.
|
||||
"""
|
||||
masked_future: Future[None] = Future()
|
||||
|
||||
if isinstance(original_result, Future):
|
||||
|
||||
def on_complete(_: Future[Any]) -> None:
|
||||
try:
|
||||
original_result.result()
|
||||
masked_future.set_result(None)
|
||||
except Exception as e:
|
||||
masked_future.set_exception(e)
|
||||
|
||||
original_result.add_done_callback(on_complete)
|
||||
else:
|
||||
# Return a completed future with None result
|
||||
masked_future.set_result(None)
|
||||
|
||||
return masked_future
|
||||
|
|
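An illustrative use of wrap_future (editor's sketch, not from the diff), covering both the Future and plain-value branches:

from concurrent.futures import ThreadPoolExecutor

from torch.distributed.checkpoint._experimental.utils import wrap_future

with ThreadPoolExecutor(max_workers=1) as pool:
    done = wrap_future(pool.submit(sum, [1, 2, 3]))  # mirrors the real future's outcome
    assert done.result() is None                     # value is masked; only success/failure remains

assert wrap_future("already staged").result() is None  # plain values become completed futures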
torch/distributed/checkpoint/_state_dict_stager.py

@@ -223,6 +223,16 @@ class StateDictStager:

        return y

    def close(self):
        """
        Clean up all cached storages and release associated resources.

        This method clears the internal storage cache, allowing garbage collection
        of cached CPU storages. Any pinned memory associated with cached storages
        will be automatically unpinned through weak reference finalizers.
        """
        self._cached_storage_mapping.clear()

    @torch.no_grad()
    def deepcopy_with_tensor_offload(self, x, memo=None, _nil=[], non_blocking=False):  # noqa: B006
        """Deep copy operation on arbitrary Python objects with special handling for PyTorch tensors.
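An illustrative lifecycle for the new close() (editor's sketch, not from the diff), using the constructor and stage() arguments that staging.py passes to StateDictStager above:

import torch
from torch.distributed.checkpoint._state_dict_stager import StateDictStager

stager = StateDictStager(pin_memory=False, share_memory=False)  # CPU-friendly settings
cpu_copy = stager.stage({"weights": torch.randn(8, 8)}, non_blocking=False)
stager.close()  # drop cached CPU storages; pinned memory is released via finalizers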
@@ -3,6 +3,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

import io
import logging
import os
import shutil
import tempfile
@@ -157,3 +158,36 @@ def with_temp_dir(
            shutil.rmtree(self.temp_dir, ignore_errors=True)

    return wrapper


def with_checkpoint_logging(
    func: Optional[Callable] = None,
    logger_name: str = "torch.distributed.checkpoint",
    level: int = logging.INFO,
) -> Optional[Callable]:
    """
    Wrapper to configure checkpoint logging for distributed tests.

    Args:
        func: The test function to wrap
        logger_name: Name of the logger to configure (default: 'torch.distributed.checkpoint')
        level: Logging level to set (default: logging.INFO)
    """
    assert func is not None

    @wraps(func)
    def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
        # Get the logger and store original level
        target_logger = logging.getLogger(logger_name)
        original_level = target_logger.level

        # Set the desired logging level
        target_logger.setLevel(level)

        try:
            func(self, *args, **kwargs)
        finally:
            # Restore original logging level
            target_logger.setLevel(original_level)

    return wrapper
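An illustrative test decoration (editor's sketch, not from the diff); the import path is an assumption based on where with_temp_dir lives today and may differ:

from torch.testing._internal.common_utils import TestCase

# Assumed import path; adjust if with_checkpoint_logging lands in a different module.
from torch.testing._internal.distributed.checkpoint_utils import with_checkpoint_logging


class MyCheckpointTest(TestCase):
    @with_checkpoint_logging  # used bare, so `func` is bound and the INFO default applies
    def test_save_emits_info_logs(self) -> None:
        ...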