diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index d16c10783c9..8c6f3326f83 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -245,20 +245,25 @@ def async_save( pg = process_group or _get_default_group() assert ( torch.device("cpu") in pg._device_types # type: ignore[attr-defined] - ), ( - "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" - ) + ), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" storage_writer = cast( StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) ) state_dict = _stateful_to_state_dict(state_dict) - if isinstance(storage_writer, AsyncStager): - staged_state_dict = storage_writer.stage(state_dict) - else: # provides bwc for storage_writers not implementing AsyncStager - staged_state_dict = _create_cpu_state_dict(state_dict) - _copy_state_dict(state_dict, staged_state_dict, type_check=False) + + @_dcp_method_logger(log_exceptions=True) + def stage_state_dict(): + if isinstance(storage_writer, AsyncStager): + staged_state_dict = storage_writer.stage(state_dict) + else: # provides bwc for storage_writers not implementing AsyncStager + staged_state_dict = _create_cpu_state_dict(state_dict) + _copy_state_dict(state_dict, staged_state_dict, type_check=False) + + return staged_state_dict + + staged_state_dict = stage_state_dict() executor: _AsyncCheckpointExecutor = ( _ProcessBasedAsyncCheckpointExecutor() @@ -274,15 +279,20 @@ def async_save( process_group=process_group, ) - if ( - isinstance(storage_writer, AsyncStager) - and storage_writer.should_synchronize_after_execute - ): - storage_writer.synchronize_staging() + @_dcp_method_logger(log_exceptions=True) + def maybe_synchronize_staging(): + if ( + isinstance(storage_writer, AsyncStager) + and storage_writer.should_synchronize_after_execute + ): + storage_writer.synchronize_staging() + + maybe_synchronize_staging() return f +@_dcp_method_logger(log_exceptions=True) def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: """Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object.""" stateful_state_dict = {}