Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #124944
* #124939
* __->__ #122965
Differential Revision: [D55493240](https://our.internmc.facebook.com/intern/diff/D55493240/)
*This PR is now ready for merge and is not an RFC*
Major choices are:
- the introduction of the AsyncStager protocol
- the removal of the `executor` parameter
- keeping async save as a separate method (for now)
This proposal seeks to add extension points to dcp.async_save, allowing users to:
- Specify a custom staging method when calling async_save
- Make the staging method itself async, for cases where we want to overlap it with the training loop (e.g., overlap the d2h copy with training and only synchronize at optim.step); see the sketch after this list
- Potentially specify the execution method used to run async_save in parallel. For example, some users may prefer a subprocess over a thread to avoid GIL issues.
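Under this proposal, a custom stager could look roughly like the following. This is a minimal sketch: the class name is made up, and the exact protocol surface (`stage`/`synchronize_staging`) and how a stager is handed to `async_save` should be taken from the code in this PR rather than from this snippet.
```
import torch


class PinnedMemoryStager:
    """Hypothetical stager following the AsyncStager protocol proposed here."""

    def __init__(self):
        self._staged = None

    def stage(self, state_dict):
        # Copy CUDA tensors to CPU (ideally pinned memory / a side stream) so
        # the d2h transfer can overlap with the training loop.
        self._staged = {
            k: v.detach().to("cpu", non_blocking=True) if isinstance(v, torch.Tensor) else v
            for k, v in state_dict.items()
        }
        return self._staged

    def synchronize_staging(self):
        # Block until outstanding d2h copies have landed, e.g. right before
        # optimizer.step() would mutate the source tensors.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
```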
A totally reasonable alternative to this entire proposal is to expect users who want this level of customization to write their own custom async save methods. Here's an example that addresses the issues mentioned in the PR comments.
```
def custom_async_save(...):
    # This step accomplishes staging and includes the usual 'planning' calls (issue 1).
    buffered_writer = CpuBufferedWriter()  # stateful; holds a CPU copy of the state_dict
    dcp.save(state_dict, storage_writer=buffered_writer)

    final_storage_writer = FileSystemWriter()
    mp.spawn(  # issue 2 is gone, do whatever you want here
        dcp.save,  # or some custom sub-process method which calls dcp.save under the hood
        buffered_writer.state_dict,  # lots of ways to do this, not really the most important part
        checkpoint_id=checkpoint_id,
        storage_writer=storage_writer,
        planner=planner,
        process_group=process_group,  # this actually wouldn't work, but again not the point
    )
    # leaving out the rest of the details for managing your extra special subprocess
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122965
Approved by: https://github.com/daulet-askarov
Enables the deduplication of saved entries by load balancing duplicates across ranks.
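The idea, roughly, is the following (an illustrative sketch, not the dedup code in this PR): when several ranks plan to write the same entry, keep exactly one writer per entry and spread those assignments so the bytes written per rank stay balanced.
```
from collections import defaultdict

# Illustrative sketch only -- not the dedup code in this PR. `plans` maps each
# rank to the entries it intends to write and their sizes in bytes.
def dedup_with_load_balancing(plans):
    owners = defaultdict(list)          # entry -> ranks that planned to write it
    for rank, items in plans.items():
        for key in items:
            owners[key].append(rank)

    load = {rank: 0 for rank in plans}  # bytes assigned to each rank so far
    keep = defaultdict(set)             # rank -> entries it actually writes
    for key, ranks in owners.items():
        winner = min(ranks, key=lambda r: load[r])
        keep[winner].add(key)
        load[winner] += plans[winner][key]

    return {
        rank: {k: v for k, v in items.items() if k in keep[rank]}
        for rank, items in plans.items()
    }


# Example: a replicated (DDP) entry "layer.weight" planned by both ranks is
# written by only one of them after dedup.
plans = {
    0: {"layer.weight": 4_000, "optim.rank0": 1_000},
    1: {"layer.weight": 4_000, "optim.rank1": 1_000},
}
print(dedup_with_load_balancing(plans))
```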
Tested with existing and modified tests. Additionally tested with the following code snippet, which saves a 20GB DDP model in **~3 seconds on 8 ranks**. Before this PR, the same operation has been measured at ~19 seconds.
```
import os
import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Model, rank_0_print, _patch_model_state_dict, and _FileSystemCheckpointer are
# helpers from the test harness and are elided here.


def run(local_rank, world_size, param_size, num_params, work_dir):
    os.environ["RANK"] = str(local_rank)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)

    model = Model(param_size=param_size, num_params=num_params)
    model = DistributedDataParallel(model, gradient_as_bucket_view=True)
    _patch_model_state_dict(model)

    sz = sum(t.nelement() * t.element_size() for t in model.parameters())
    rank_0_print(f"Model size: {sz / 1_000_000_000.0} GB")

    rank_0_print("Saving the model with DCP...")
    checkpointer = _FileSystemCheckpointer(
        f"{work_dir}/dcp",
        sync_files=False,
        single_file_per_rank=False,
        thread_count=1,
    )
    begin_ts = time.monotonic()
    checkpointer.save(state_dict={"model": model})
    end_ts = time.monotonic()
    rank_0_print(f"Took {end_ts - begin_ts} seconds with DCP")
```
Differential Revision: [D52435926](https://our.internmc.facebook.com/intern/diff/D52435926/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116469
Approved by: https://github.com/fegin, https://github.com/wz337
Adds a useful high-level wrapper for calling `DCP.save/load` with the correct storage readers and writers.
Instead of doing:
```
DCP.save(
state_dict={...},
storage_writer=StorageWriter(...)
)
DCP.load(
state_dict={...},
storage_reader=StorageReader(...)
)
```
We can now do:
```
checkpointer = Checkpointer(...)
checkpointer.save(state_dict={...})
checkpointer.load(state_dict={...})
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114603
Approved by: https://github.com/fegin, https://github.com/wz337
`state_dict` is a very common variable name people use for a local state_dict, and `load_state_dict` conflicts with DCP's `load_state_dict`.
This PR renames `state_dict` to `get_state_dict`. `get_state_dict` is closer to what this API does -- users call it to get the current state_dict, either for saving or for loading (in which case the result is passed to DCP to be loaded in-place).
This PR also renames `load_state_dict` to `set_state_dict`. `set_state_dict` is less ideal than `get_state_dict` but is symmetric. We can still change the API name before it goes to beta.
This PR also simplifies the API signatures: `model_only` is removed, and `optim_only` only exists for `get_state_dict`.
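The intended call pattern looks roughly like this (a minimal sketch using current `dcp.save`/`dcp.load` entry points; the model, optimizer, and path are stand-ins):
```
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

# Stand-ins; in practice this would be a DDP/FSDP-wrapped model and its optimizer.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())

# Saving: fetch FQN-keyed state_dicts for the model and optimizer, then save them.
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.save(
    {"model": model_sd, "optim": optim_sd},
    storage_writer=dcp.FileSystemWriter("/tmp/ckpt"),  # path is illustrative
)

# Loading: fetch the current state_dicts, let DCP load into them in-place,
# then push the result back into the model and optimizer.
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.load(
    {"model": model_sd, "optim": optim_sd},
    storage_reader=dcp.FileSystemReader("/tmp/ckpt"),
)
set_state_dict(model, optimizer, model_state_dict=model_sd, optim_state_dict=optim_sd)
```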
Differential Revision: [D50213931](https://our.internmc.facebook.com/intern/diff/D50213931/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111120
Approved by: https://github.com/wz337
ghstack dependencies: #111106, #111107, #111275, #111109, #111110