mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DCP] Add logging for _stateful_to_state_dict(), stage_state_dict(), and synchronize_staging() (#151320)
Summary: As titled. Test Plan: CI Differential Revision: D73040700 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151320 Approved by: https://github.com/saumishr
This commit is contained in:
parent
c5b10ff119
commit
473a38b562
|
|
@ -245,21 +245,26 @@ def async_save(
|
|||
pg = process_group or _get_default_group()
|
||||
assert (
|
||||
torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
|
||||
), (
|
||||
"A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
|
||||
)
|
||||
), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
|
||||
|
||||
storage_writer = cast(
|
||||
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
|
||||
)
|
||||
|
||||
state_dict = _stateful_to_state_dict(state_dict)
|
||||
|
||||
@_dcp_method_logger(log_exceptions=True)
|
||||
def stage_state_dict():
|
||||
if isinstance(storage_writer, AsyncStager):
|
||||
staged_state_dict = storage_writer.stage(state_dict)
|
||||
else: # provides bwc for storage_writers not implementing AsyncStager
|
||||
staged_state_dict = _create_cpu_state_dict(state_dict)
|
||||
_copy_state_dict(state_dict, staged_state_dict, type_check=False)
|
||||
|
||||
return staged_state_dict
|
||||
|
||||
staged_state_dict = stage_state_dict()
|
||||
|
||||
executor: _AsyncCheckpointExecutor = (
|
||||
_ProcessBasedAsyncCheckpointExecutor()
|
||||
if async_checkpointer_type == AsyncCheckpointerType.PROCESS
|
||||
|
|
@ -274,15 +279,20 @@ def async_save(
|
|||
process_group=process_group,
|
||||
)
|
||||
|
||||
@_dcp_method_logger(log_exceptions=True)
|
||||
def maybe_synchronize_staging():
|
||||
if (
|
||||
isinstance(storage_writer, AsyncStager)
|
||||
and storage_writer.should_synchronize_after_execute
|
||||
):
|
||||
storage_writer.synchronize_staging()
|
||||
|
||||
maybe_synchronize_staging()
|
||||
|
||||
return f
|
||||
|
||||
|
||||
@_dcp_method_logger(log_exceptions=True)
|
||||
def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
|
||||
"""Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object."""
|
||||
stateful_state_dict = {}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user