.. role:: hidden
    :class: hidden-section

Distributed Checkpoint - torch.distributed.checkpoint
=====================================================

Distributed Checkpoint (DCP) supports loading and saving models from multiple ranks in parallel.
It handles load-time resharding, which enables saving in one cluster topology and loading into another.

DCP is different from `torch.save` and `torch.load` in a few significant ways:

* It produces multiple files per checkpoint, with at least one per rank.
* It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.

The entrypoints to load and save a checkpoint are the following:

.. automodule:: torch.distributed.checkpoint

.. currentmodule:: torch.distributed.checkpoint

.. autofunction:: load_state_dict
.. autofunction:: save_state_dict

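As a minimal sketch of these entrypoints, the round trip below runs in a single process; the temporary `checkpoint_dir` and `no_dist=True` (which lets DCP run without an initialized process group) are assumptions for illustration, not the typical multi-rank setup:

```python
import tempfile

import torch
import torch.distributed.checkpoint as dcp

model = torch.nn.Linear(4, 4)
checkpoint_dir = tempfile.mkdtemp()  # assumed local path for the sketch

# Save: DCP writes one or more files per rank into checkpoint_dir.
dcp.save_state_dict(
    state_dict=model.state_dict(),
    storage_writer=dcp.FileSystemWriter(checkpoint_dir),
    no_dist=True,  # single-process; normally a process group is used
)

# Load: DCP loads *in place*, so allocate the destination tensors first
# (here by instantiating a fresh model) and pass its state_dict.
new_model = torch.nn.Linear(4, 4)
state_dict = new_model.state_dict()
dcp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dcp.FileSystemReader(checkpoint_dir),
    no_dist=True,
)
new_model.load_state_dict(state_dict)
```

In a real distributed job, every rank calls `save_state_dict` and `load_state_dict` collectively with its shard of the state.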
This `example <https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py>`_ shows how to use PyTorch Distributed Checkpoint to save an FSDP model.

The following types define the IO interface used during checkpoint:

.. autoclass:: torch.distributed.checkpoint.StorageReader
   :members:

.. autoclass:: torch.distributed.checkpoint.StorageWriter
   :members:

The following types define the planner interface used during checkpoint:

.. autoclass:: torch.distributed.checkpoint.LoadPlanner
   :members:

.. autoclass:: torch.distributed.checkpoint.LoadPlan
   :members:

.. autoclass:: torch.distributed.checkpoint.ReadItem
   :members:

.. autoclass:: torch.distributed.checkpoint.SavePlanner
   :members:

.. autoclass:: torch.distributed.checkpoint.SavePlan
   :members:

.. autoclass:: torch.distributed.checkpoint.WriteItem
   :members:

We provide a filesystem-based storage layer:

.. autoclass:: torch.distributed.checkpoint.FileSystemReader
   :members:

.. autoclass:: torch.distributed.checkpoint.FileSystemWriter
   :members:

We provide default implementations of `LoadPlanner` and `SavePlanner` that
can handle all torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor.

.. autoclass:: torch.distributed.checkpoint.DefaultSavePlanner
   :members:

.. autoclass:: torch.distributed.checkpoint.DefaultLoadPlanner
   :members:

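The default planners are what the entrypoints use when no planner is given; they can also be passed explicitly, which is where a custom `SavePlanner`/`LoadPlanner` subclass would plug in. The sketch below is a single-process illustration (the temporary directory and `no_dist=True` are assumptions, not the usual multi-rank setup):

```python
import tempfile

import torch
import torch.distributed.checkpoint as dcp

model = torch.nn.Linear(4, 4)
path = tempfile.mkdtemp()  # assumed local path for the sketch

# Passing DefaultSavePlanner explicitly is equivalent to omitting it;
# a custom planner subclassing SavePlanner would be passed the same way.
dcp.save_state_dict(
    state_dict=model.state_dict(),
    storage_writer=dcp.FileSystemWriter(path),
    planner=dcp.DefaultSavePlanner(),
    no_dist=True,
)

# Load in place into freshly allocated tensors, again naming the planner.
sd = torch.nn.Linear(4, 4).state_dict()
dcp.load_state_dict(
    state_dict=sd,
    storage_reader=dcp.FileSystemReader(path),
    planner=dcp.DefaultLoadPlanner(),
    no_dist=True,
)
```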
We provide a set of APIs to help users get and set a state_dict easily. This is
an experimental feature and is subject to change.

.. autofunction:: torch.distributed.checkpoint.state_dict.get_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.set_state_dict

.. autoclass:: torch.distributed.checkpoint.state_dict.StateDictOptions
   :members:
