pytorch/torch/distributed/fsdp/_common_utils.py
Andrew Gu e667c00656 [FSDP()][2/N] Refactor training state (#87916)
This PR makes meaningful changes: we stratify `TrainingState` into two levels, one per FSDP instance and one per `FlatParamHandle`/`FlatParameter`.
- At the FSDP instance level, we only care about `IDLE`, FSDP computation (i.e. `FORWARD_BACKWARD`), and `SUMMON_FULL_PARAMS`. These states dynamically modify behavior (e.g. `summon_full_params()` forces full precision); see the sketch below for how the two levels interact.
- At the `FlatParamHandle` level, we care about the training state for invariants and debugging. Hence, we keep `IDLE`, `FORWARD`, `BACKWARD_PRE`, `BACKWARD_POST`, and `SUMMON_FULL_PARAMS`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87916
Approved by: https://github.com/mrshenli
2022-10-29 06:50:30 +00:00
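To make the stratification concrete, here is a minimal usage sketch. It is not FSDP's actual internals: `_SketchHandle` and `_SketchFSDP` are hypothetical stand-ins, and only the two enums are imported from this module (assuming the `torch.distributed.fsdp._common_utils` import path that this file's location implies).

# Hypothetical sketch: the instance-level TrainingState gates dynamic
# behavior, while the handle-level HandleTrainingState enforces invariants.
from torch.distributed.fsdp._common_utils import (
    HandleTrainingState,
    TrainingState,
)


class _SketchHandle:
    """Illustrative stand-in for a ``FlatParamHandle``'s state tracking."""

    def __init__(self) -> None:
        self._training_state = HandleTrainingState.IDLE

    def pre_forward(self) -> None:
        # Invariant check: a handle must be IDLE before entering FORWARD.
        assert self._training_state == HandleTrainingState.IDLE, (
            f"Expected IDLE, got {self._training_state}"
        )
        self._training_state = HandleTrainingState.FORWARD

    def post_forward(self) -> None:
        assert self._training_state == HandleTrainingState.FORWARD
        self._training_state = HandleTrainingState.IDLE


class _SketchFSDP:
    """Illustrative stand-in for a ``FullyShardedDataParallel`` instance."""

    def __init__(self) -> None:
        self.training_state = TrainingState.IDLE
        self._handle = _SketchHandle()

    def forward(self, x):
        # The instance-level state is what dynamically modifies behavior
        # (e.g. ``summon_full_params()`` would set SUMMON_FULL_PARAMS to
        # force full precision); here we enter FSDP computation.
        self.training_state = TrainingState.FORWARD_BACKWARD
        self._handle.pre_forward()
        try:
            return x  # stand-in for the wrapped module's computation
        finally:
            self._handle.post_forward()
            self.training_state = TrainingState.IDLE

The design point this illustrates: assertions against `HandleTrainingState` catch invariant violations close to where they occur (useful for debugging), while branching on `TrainingState` is what actually changes runtime behavior.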


from enum import auto, Enum


class TrainingState(Enum):
    """
    An enum that indicates the state of a ``FullyShardedDataParallel``
    instance.
    """

    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class HandleTrainingState(Enum):
    """
    An enum that indicates the state of a ``FlatParamHandle``.
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()