[DCP] Add logging for _stateful_to_state_dict(), stage_state_dict(), and synchronize_staging() (#151320)

Summary: As titled.

Test Plan: CI

Differential Revision: D73040700

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151320
Approved by: https://github.com/saumishr
This commit is contained in:
Meet Vadakkanchery 2025-04-17 01:08:32 +00:00 committed by Nikita Shulga
parent c5b10ff119
commit 473a38b562

View File

@ -245,21 +245,26 @@ def async_save(
pg = process_group or _get_default_group()
assert (
torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
), (
"A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
)
), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
storage_writer = cast(
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
)
state_dict = _stateful_to_state_dict(state_dict)
@_dcp_method_logger(log_exceptions=True)
def stage_state_dict():
if isinstance(storage_writer, AsyncStager):
staged_state_dict = storage_writer.stage(state_dict)
else: # provides bwc for storage_writers not implementing AsyncStager
staged_state_dict = _create_cpu_state_dict(state_dict)
_copy_state_dict(state_dict, staged_state_dict, type_check=False)
return staged_state_dict
staged_state_dict = stage_state_dict()
executor: _AsyncCheckpointExecutor = (
_ProcessBasedAsyncCheckpointExecutor()
if async_checkpointer_type == AsyncCheckpointerType.PROCESS
@ -274,15 +279,20 @@ def async_save(
process_group=process_group,
)
@_dcp_method_logger(log_exceptions=True)
def maybe_synchronize_staging():
if (
isinstance(storage_writer, AsyncStager)
and storage_writer.should_synchronize_after_execute
):
storage_writer.synchronize_staging()
maybe_synchronize_staging()
return f
@_dcp_method_logger(log_exceptions=True)
def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
"""Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object."""
stateful_state_dict = {}