Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
This PR actually has meaningful changes. We stratify `TrainingState` into two levels: one is per FSDP instance and one is per `FlatParamHandle`/`FlatParameter`.

- At the FSDP instance level, we only care about `IDLE`, FSDP computation (i.e. `FORWARD_BACKWARD`), or `SUMMON_FULL_PARAMS`. These dynamically modify behavior (e.g. `summon_full_params()` forces full precision).
- At the `FlatParamHandle` level, we care about the training state for invariants and debugging. Hence, we keep `IDLE`, `FORWARD`, `BACKWARD_PRE`, `BACKWARD_POST`, and `SUMMON_FULL_PARAMS`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87916
Approved by: https://github.com/mrshenli
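The instance-level idea in the commit message, where the FSDP instance's state dynamically changes behavior, can be sketched as follows. This is a minimal illustration under stated assumptions, not PyTorch's implementation: `ToyFSDP`, `param_dtype()`, `forward_backward()`, and the string dtypes are hypothetical, and `summon_full_params()` here only mimics the idea that summoning full parameters forces full precision.

```python
from contextlib import contextmanager
from enum import Enum, auto


class TrainingState(Enum):
    # Instance-level states, copied from the enum in this file.
    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class ToyFSDP:
    """Hypothetical stand-in for an FSDP instance; it only tracks state."""

    def __init__(self, low_precision_dtype: str = "float16") -> None:
        self.training_state = TrainingState.IDLE
        self.low_precision_dtype = low_precision_dtype

    def param_dtype(self) -> str:
        # The instance-level state changes behavior dynamically:
        # summoning full parameters forces full precision.
        if self.training_state == TrainingState.SUMMON_FULL_PARAMS:
            return "float32"
        return self.low_precision_dtype

    @contextmanager
    def _enter_state(self, state: TrainingState):
        # Only enter a non-idle state from IDLE, and always restore IDLE
        # on exit, even if the body raises.
        if self.training_state != TrainingState.IDLE:
            raise RuntimeError(f"Expected IDLE but got {self.training_state}")
        self.training_state = state
        try:
            yield self
        finally:
            self.training_state = TrainingState.IDLE

    def forward_backward(self):
        # A single state covers both forward and backward computation.
        return self._enter_state(TrainingState.FORWARD_BACKWARD)

    def summon_full_params(self):
        return self._enter_state(TrainingState.SUMMON_FULL_PARAMS)


model = ToyFSDP()
with model.forward_backward():
    print(model.param_dtype())  # float16
with model.summon_full_params():
    print(model.param_dtype())  # float32
print(model.training_state)  # TrainingState.IDLE
```

Using a context manager with `try`/`finally` is one way to guarantee the instance falls back to `IDLE`; the coarse instance-level enum only needs to answer "which mode am I in?", leaving finer-grained tracking to the handle level.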
24 lines
472 B
Python
from enum import auto, Enum


class TrainingState(Enum):
    """
    An enum that indicates the state of a ``FullyShardedDataParallel`` instance.
    """

    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class HandleTrainingState(Enum):
    """
    An enum that indicates the state of a ``FlatParamHandle``.
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()