pytorch/torch/distributed/fsdp
Rohan Varma a0b3814433 Clean prefixes when searching for params / buffers to ignore (#78278)
Co-authored with: @awgu

When `state_dict` has a prefix attached to it, the current logic for ignoring parameters and buffers does not work since it doesn't account for this prefix. To fix this, we make the following changes:

- Clean the key if it starts with the prefix. Note that not all keys necessarily start with the prefix, e.g. when the current module's `state_dict` post-hook is running and a previous module's `state_dict` has already been computed, where that previous module is at the same level of the hierarchy as the current module. A sketch of this cleaning step appears after this list.
- This prefixing means it is no longer correct to override child modules' ignored params and buffers with the root FSDP instance's (this would break if a child FSDP instance had ignored modules but the root did not, for example). We fix this by having each parent know about the ignored modules of its children and by computing fully qualified names for ignored params and buffers (see the second sketch below).
- This means that for a particular FSDP instance, that instance knows the fully qualified names, for itself and its children, that it needs to ignore. It does not know about its parents' ignored params and buffers, but it does not need to store that data.
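
A minimal sketch of the prefix-cleaning idea, using hypothetical helper names `_clean_key` and `_should_ignore`; the actual code in `fully_sharded_data_parallel.py` may differ in detail:

```python
def _clean_key(key: str, prefix: str) -> str:
    # Strip the prefix only when present: keys coming from a sibling module
    # whose state_dict was computed earlier may not carry this module's prefix.
    return key[len(prefix):] if key.startswith(prefix) else key

def _should_ignore(key: str, prefix: str, ignored_names: set) -> bool:
    # Compare the cleaned key against the fully qualified names that this
    # FSDP instance tracks for itself and its children.
    return _clean_key(key, prefix) in ignored_names
```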
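
And a minimal sketch of computing fully qualified names for ignored parameters, assuming a hypothetical helper `_get_ignored_param_names`; the real FSDP logic differs but follows the same shape:

```python
import torch.nn as nn

def _get_ignored_param_names(root: nn.Module, ignored_modules: set) -> set:
    ignored = set()
    for module_name, module in root.named_modules():
        if module in ignored_modules:
            for param_name, _ in module.named_parameters():
                # Prefix the parameter name with the module's path from the
                # root, so each FSDP instance stores the names it must ignore
                # (for itself and its children) in fully qualified form.
                fqn = f"{module_name}.{param_name}" if module_name else param_name
                ignored.add(fqn)
    return ignored
```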
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78278
Approved by: https://github.com/awgu
2022-05-26 02:43:03 +00:00
__init__.py [FSDP] Full state_dict rank0 only and CPU offload 2022-04-21 13:28:18 +00:00
_optim_utils.py [FSDP] Remove unneeded padding logic for optim state dict 2022-05-25 17:22:03 +00:00
_utils.py CheckpointWrapper state_dict fix (#77224) 2022-05-17 03:39:31 +00:00
flatten_params_wrapper.py CheckpointWrapper state_dict fix (#77224) 2022-05-17 03:39:31 +00:00
fully_sharded_data_parallel.py Clean prefixes when searching for params / buffers to ignore (#78278) 2022-05-26 02:43:03 +00:00
shard_utils.py [FSDP] Implement sharded_state_dict and load_sharded_state_dict 2022-05-15 22:48:56 +00:00
sharded_grad_scaler.py [FSDP] Sharded Grad Scaler (#76918) 2022-05-16 15:53:21 +00:00
wrap.py [Reland] Mixed precision batchnorm fix (#77234) 2022-05-11 15:03:34 +00:00