For both exposing the original parameters and for TP integration, we cannot only rely on `isinstance(param, FlatParameter)` to ignore already-flattened parameters in `.named_parameters()`. As a simple workaround, we can mark original parameters or `ShardedTensor`s with an attribute `_fsdp_flattened` (saved as a string variable `FSDP_FLATTENED`) to indicate that the parameter/tensor has already been flattened. This issue only arises for recursive/nested FSDP wrapping.
This PR also changes `isinstance(param, FlatParameter)` checks to `type(param) is FlatParameter` because all tensor subclasses that have `_is_param == True` will return `True` for `isinstance(param, <any subclass with _is_param == True>)`. This means that a `ShardedTensor` parameter will return `True` for `isinstance(st, FlatParameter)`, which is not what we want.
5271494ef2/torch/nn/parameter.py (L8-L10)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85038
Approved by: https://github.com/rohan-varma
- Uses state dict / load state dict hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into non-checkpointed wrapped module.
This is because a training run can use activation checkpointing, then we can recover `state_dict`, and a future run may not want to wrap modules with activation checkpointing or decide to change activation checkpoint wrapping structure. To support this, we add hooks to remove / add the relevant prefix as needed.
Tests are added to ensure we can load into CheckpointWrapper module as well as local module from CheckpointWrapper-wrapped module. state_dict with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao