mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DCP] Add logging for _stateful_to_state_dict(), stage_state_dict(), and synchronize_staging() (#151320)
Summary: As titled. Test Plan: CI Differential Revision: D73040700 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151320 Approved by: https://github.com/saumishr
This commit is contained in:
parent
c5b10ff119
commit
473a38b562
|
|
@ -245,20 +245,25 @@ def async_save(
|
||||||
pg = process_group or _get_default_group()
|
pg = process_group or _get_default_group()
|
||||||
assert (
|
assert (
|
||||||
torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
|
torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
|
||||||
), (
|
), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
|
||||||
"A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
|
|
||||||
)
|
|
||||||
|
|
||||||
storage_writer = cast(
|
storage_writer = cast(
|
||||||
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
|
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
state_dict = _stateful_to_state_dict(state_dict)
|
state_dict = _stateful_to_state_dict(state_dict)
|
||||||
if isinstance(storage_writer, AsyncStager):
|
|
||||||
staged_state_dict = storage_writer.stage(state_dict)
|
@_dcp_method_logger(log_exceptions=True)
|
||||||
else: # provides bwc for storage_writers not implementing AsyncStager
|
def stage_state_dict():
|
||||||
staged_state_dict = _create_cpu_state_dict(state_dict)
|
if isinstance(storage_writer, AsyncStager):
|
||||||
_copy_state_dict(state_dict, staged_state_dict, type_check=False)
|
staged_state_dict = storage_writer.stage(state_dict)
|
||||||
|
else: # provides bwc for storage_writers not implementing AsyncStager
|
||||||
|
staged_state_dict = _create_cpu_state_dict(state_dict)
|
||||||
|
_copy_state_dict(state_dict, staged_state_dict, type_check=False)
|
||||||
|
|
||||||
|
return staged_state_dict
|
||||||
|
|
||||||
|
staged_state_dict = stage_state_dict()
|
||||||
|
|
||||||
executor: _AsyncCheckpointExecutor = (
|
executor: _AsyncCheckpointExecutor = (
|
||||||
_ProcessBasedAsyncCheckpointExecutor()
|
_ProcessBasedAsyncCheckpointExecutor()
|
||||||
|
|
@ -274,15 +279,20 @@ def async_save(
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
@_dcp_method_logger(log_exceptions=True)
|
||||||
isinstance(storage_writer, AsyncStager)
|
def maybe_synchronize_staging():
|
||||||
and storage_writer.should_synchronize_after_execute
|
if (
|
||||||
):
|
isinstance(storage_writer, AsyncStager)
|
||||||
storage_writer.synchronize_staging()
|
and storage_writer.should_synchronize_after_execute
|
||||||
|
):
|
||||||
|
storage_writer.synchronize_staging()
|
||||||
|
|
||||||
|
maybe_synchronize_staging()
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
@_dcp_method_logger(log_exceptions=True)
|
||||||
def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
|
def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
|
||||||
"""Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object."""
|
"""Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object."""
|
||||||
stateful_state_dict = {}
|
stateful_state_dict = {}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user