pytorch/torch/distributed/algorithms
Rohan Varma f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses state_dict / load_state_dict hooks to ensure that the state_dict of a module wrapped with `CheckpointWrapper` can be loaded into a non-checkpoint-wrapped module.

This is needed because a training run may use activation checkpointing and save a `state_dict`, while a later run may not wrap modules with activation checkpointing at all, or may change the checkpoint wrapping structure. To support this, we add hooks that remove or add the relevant prefix as needed, as sketched below.
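
A minimal sketch of the hook-based prefix handling, not the actual `CheckpointWrapper` implementation: the wrapper class, the attribute name `_checkpoint_wrapped_module`, and the prefix constant are assumptions for illustration. It relies on `nn.Module._register_state_dict_hook` and `nn.Module._register_load_state_dict_pre_hook`, which are private PyTorch APIs.

```python
import torch.nn as nn

# Assumed prefix that the wrapper would otherwise inject into state_dict keys.
_WRAPPED_PREFIX = "_checkpoint_wrapped_module."


class PrefixStrippingWrapper(nn.Module):
    """Hypothetical wrapper whose state_dict keys match the inner module's."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self._checkpoint_wrapped_module = module
        # Post-hook: after state_dict() collects keys, drop the wrapper prefix
        # so the result looks like the unwrapped module's state_dict.
        self._register_state_dict_hook(self._strip_prefix_hook)
        # Pre-hook: before load_state_dict() consumes keys, re-insert the
        # prefix so keys saved from an unwrapped module map onto this wrapper.
        self._register_load_state_dict_pre_hook(self._add_prefix_hook)

    def forward(self, *args, **kwargs):
        return self._checkpoint_wrapped_module(*args, **kwargs)

    @staticmethod
    def _strip_prefix_hook(module, state_dict, prefix, local_metadata):
        # Rename "<prefix>_checkpoint_wrapped_module.weight" -> "<prefix>weight".
        for key in list(state_dict.keys()):
            if _WRAPPED_PREFIX in key:
                state_dict[key.replace(_WRAPPED_PREFIX, "")] = state_dict.pop(key)
        return state_dict

    def _add_prefix_hook(self, state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
        # Rename "<prefix>weight" -> "<prefix>_checkpoint_wrapped_module.weight"
        # in place so the wrapped submodule finds its parameters during load.
        for key in list(state_dict.keys()):
            if key.startswith(prefix):
                new_key = prefix + _WRAPPED_PREFIX + key[len(prefix):]
                state_dict[new_key] = state_dict.pop(key)
```

With hooks like these, a checkpoint saved from a wrapped model can be loaded into an unwrapped copy of the inner module, and a checkpoint saved without wrapping can be loaded after wrapping is added or restructured.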

Tests are added to ensure that a state_dict saved from a CheckpointWrapper-wrapped module can be loaded into both a CheckpointWrapper-wrapped module and a plain local module. state_dict behavior with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00
_checkpoint CheckpointWrapper state_dict fix (#77224) 2022-05-17 03:39:31 +00:00
_optimizer_overlap make fsdp folder to be public (#72084) 2022-02-02 15:50:14 +00:00
ddp_comm_hooks [Model Averaging] Support disabling post-local gradient sync (#76723) 2022-05-16 18:09:09 +00:00
model_averaging [Model Averaging] Make an error message more clear in hierarchical_model_averager.py 2022-04-26 15:20:51 +00:00
quantization [BE] minor improvement to dist quantization (#67401) 2021-11-21 23:31:59 -08:00
__init__.py Make _Join, _Joinable, _JoinHook public (#62605) 2021-08-03 12:20:11 -07:00
join.py [Join][BE] Fix typo; remove obsolete method (#72886) 2022-02-16 15:03:09 +00:00