Co-authored with: @awgu
When `state_dict` has a prefix attached to it, the current logic for ignoring parameters and buffers does not work since it doesn't account for this prefix. To fix this, we make the following changes:
- Clean the key if it starts with the prefix. Note that not all keys necessarily start with the prefix, e.g. when the current module's state_dict post-hook runs after a sibling module (at the same level of the hierarchy) has already had its `state_dict` computed.
- This prefixing also means it is not correct to override a child module's ignored params and buffers with the root FSDP instance's (this would not work if, for example, child FSDP instances had ignored modules but the root did not). We fix this by having each parent know about the ignored modules of its children and by computing fully qualified names for ignored params and buffers.
- As a result, each FSDP instance knows the fully qualified names, from itself and its children, that it needs to ignore. It does not know about its parent's ignored params and buffers, but it does not need to store that data (a sketch follows below).
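A minimal sketch of the post-hook idea, assuming an illustrative attribute name for the set of ignored fully qualified names and a simplified hook signature (the real FSDP internals differ):

```python
import torch.nn as nn

def _post_state_dict_hook(module: nn.Module, state_dict, prefix: str, *args):
    # Illustrative only: `_ignored_param_and_buffer_fqns` stands in for the
    # set of fully qualified names this FSDP instance collected from itself
    # and its children; the real attribute names differ.
    ignored_fqns = getattr(module, "_ignored_param_and_buffer_fqns", set())
    for key in list(state_dict.keys()):
        # Not every key necessarily starts with this module's prefix
        # (e.g. keys from a sibling module processed earlier).
        if not key.startswith(prefix):
            continue
        clean_key = key[len(prefix):]
        if clean_key in ignored_fqns:
            del state_dict[key]
    return state_dict
```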
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78278
Approved by: https://github.com/awgu
After offline discussion, we decided that moving a CPU module to GPU by default is a bit too risky due to possible OOM during init.
Theoretically, we should not OOM, because a module being wrapped by FSDP is already required to fit on GPU, e.g. during forward. But temporary GPU tensors allocated during `__init__` could break this assumption, so it is better for now to give users a way to init on CPU if needed.
We still warn users to pass `device_id` for faster init if the model is on CPU.
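A minimal usage sketch of the `device_id` path (assuming the process group is already initialized; `MyModel` is a placeholder for the user's CPU-constructed module):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Model is constructed on CPU; passing `device_id` lets FSDP move it to the
# given GPU during init instead of initializing (slowly) on CPU.
model = MyModel()  # hypothetical CPU-resident module
fsdp_model = FSDP(model, device_id=torch.cuda.current_device())
```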
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77720
Approved by: https://github.com/zhaojuanmao
Introduce error handling across all ranks when loading and saving checkpoints.
This makes it a lot simpler for users to handle failures and, as a positive side effect, to coordinate on when a checkpoint has successfully finished.
This change requires 3 collectives when saving and 1 when loading.
All of those collectives carry a small payload, so they are latency bound and write time should dominate.
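A hedged usage sketch, assuming the `torch.distributed._shard.checkpoint` module path and names of that era (the API has since moved under `torch.distributed.checkpoint`):

```python
from torch.distributed._shard.checkpoint import (
    CheckpointException,
    FileSystemWriter,
    save_state_dict,
)

state_dict = model.state_dict()  # hypothetical model
writer = FileSystemWriter("/tmp/checkpoint")

try:
    # Failures on any rank are collected and re-raised on every rank,
    # so all ranks observe the same outcome of the save.
    save_state_dict(state_dict=state_dict, storage_writer=writer)
except CheckpointException as exc:
    # The exception carries the per-rank failures gathered via the
    # coordination collectives described above.
    print(f"checkpoint save failed: {exc}")
```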
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77091
Approved by: https://github.com/pritamdamania87, https://github.com/wanchaol
This PR does a number of things:
- Move linalg.vector_norm to structured kernels and simplify the logic
- Fixes a number of preexisting issues with the `dtype` kwarg of these ops (see the example after this list)
- Heavily simplifies and corrects the logic of `linalg.matrix_norm` and `linalg.norm` to be consistent with the docs
  - Before this change, the `_out` versions of these functions were incorrect
  - Their implementation is now as efficient as expected, as it avoids reimplementing these operations whenever possible
- Deprecates `torch.frobenius_norm` and `torch.nuclear_norm`: they were exposed in the API and are apparently used in mobile (??!!), even though they were undocumented and their implementation was slow
  - I'd love to get rid of these functions already, but I guess we have to go through their deprecation first
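For illustration, a small example of the `torch.linalg` entry points to use going forward (standard public APIs, shown only as a usage sketch):

```python
import torch

x = torch.randn(3, 4, dtype=torch.float32)

# Norm over the last dim, computed in float64 via the dtype kwarg.
n = torch.linalg.vector_norm(x, ord=2, dim=-1, dtype=torch.float64)

# Frobenius / nuclear norms should go through torch.linalg instead of the
# deprecated torch.frobenius_norm / torch.nuclear_norm.
fro = torch.linalg.matrix_norm(torch.randn(5, 5), ord="fro")
nuc = torch.linalg.matrix_norm(torch.randn(5, 5), ord="nuc")
```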
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76547
Approved by: https://github.com/mruberry
- Uses state dict / load state dict hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into a module that is not checkpoint-wrapped.
This is needed because a training run may use activation checkpointing and save its `state_dict`, while a future run may not want to wrap modules with activation checkpointing, or may change the checkpoint wrapping structure. To support this, we add hooks that remove / add the relevant prefix as needed (see the sketch below).
Tests are added to ensure we can load a CheckpointWrapper-wrapped module's `state_dict` both into another CheckpointWrapper module and into a local (unwrapped) module. `state_dict` with FSDP is also verified.
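A minimal sketch of the hook idea; the prefix string below is an assumption about the wrapper's internal attribute name, not necessarily the exact constant used by `CheckpointWrapper`:

```python
# Assumed prefix: the wrapper is taken to store the inner module under an
# attribute like `_checkpoint_wrapped_module` (implementation detail).
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."

def _strip_checkpoint_prefix_hook(module, state_dict, prefix, *args):
    # Post-hook: rewrite keys so the resulting state_dict matches a module
    # that is not wrapped with CheckpointWrapper.
    for key in list(state_dict.keys()):
        if _CHECKPOINT_PREFIX in key:
            new_key = key.replace(_CHECKPOINT_PREFIX, "")
            state_dict[new_key] = state_dict.pop(key)
    return state_dict

# Conversely, a load_state_dict pre-hook would re-insert the prefix so a
# plain state_dict can be loaded into a CheckpointWrapper-wrapped module.
```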
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
`gather_object` is problematic when used with Tensors, as they can unpickle on the wrong device and lead to deadlocks or spurious failures.
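For illustration of that pitfall only (this is the usual general mitigation, not this change's RPC workaround): move tensors to CPU before gathering so they do not unpickle onto the wrong CUDA device on the destination rank.

```python
import torch.distributed as dist

# Hypothetical payload: detach tensors to CPU before gather_object so the
# receiving rank does not unpickle them onto the wrong CUDA device.
payload = {"loss": some_cuda_tensor.cpu()}  # some_cuda_tensor is hypothetical
output = [None] * dist.get_world_size() if dist.get_rank() == 0 else None
dist.gather_object(payload, output, dst=0)
```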
This change introduces an RPC workaround for EFA when initializing TensorPipe, until TensorPipe properly addresses the issue.
Fixes #73935
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77272
Approved by: https://github.com/pritamdamania87
I find that disabling intra-subgroup gradient allreduce can still give satisfying accuracy in some cases, so it is better to make such gradient averaging configurable. This does not even take into account the communication savings from skipping those gradient allreduces.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76723
Approved by: https://github.com/rohan-varma
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77356
Implement ShardedTensor compatible sharded_state_dict() and load_sharded_state_dict().
Algorithm overview:
sharded_state_dict():
1. Call summon_full_params().
2. For each unflattened, non-sharded parameter:
2.1 Call chunk() to get the local shard of the parameter.
2.2 Create a ShardedTensor.
3. Replace the tensor in the state_dict with the newly created ShardedTensor (see the sketch after this overview).
load_sharded_state_dict():
1. For each unflattened, sharded parameter (ShardedTensor) in the given state_dict:
1.1 Pop out from the state_dict.
1.2 Do allgather to reconstruct the unflattened, non-sharded parameter.
2. Create a FlatParameter with the unflattened, non-sharded parameters.
3. Shard the newly created FlatParameter.
4. Insert the new FlatParameter into the state_dict.
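A hedged sketch of steps 2-3 of `sharded_state_dict()` above, using the `torch.distributed._shard.sharded_tensor` helpers; `_to_sharded_tensor` is a hypothetical helper, and the real implementation handles uneven shards and full sharding metadata:

```python
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    init_from_local_shards,
)

def _to_sharded_tensor(full_param: torch.Tensor) -> ShardedTensor:
    # Chunk the unflattened, non-sharded parameter along dim 0 and wrap this
    # rank's chunk in a ShardedTensor. For simplicity this assumes dim 0
    # divides evenly across ranks.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    local_chunk = full_param.chunk(world_size, dim=0)[rank].clone()
    offsets = [rank * local_chunk.size(0)] + [0] * (full_param.dim() - 1)
    shard = Shard.from_tensor_and_offsets(local_chunk, offsets, rank)
    return init_from_local_shards([shard], full_param.size())
```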
Differential Revision: [D36284983](https://our.internmc.facebook.com/intern/diff/D36284983/)
**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36284983/)!
Approved by: https://github.com/zhaojuanmao