[FSDP][Easy] Move _FSDPState attrs to avoid comment confusion (#106392)

Resubmit of https://github.com/pytorch/pytorch/pull/106333 after rebasing (I lost the original branch locally)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106392
Approved by: https://github.com/kwen2501
Author:    Andrew Gu
Date:      2023-08-01 16:22:43 +00:00
Committed: PyTorch MergeBot
Parent:    5c3aae8385
Commit:    506b55fc29


@@ -114,6 +114,7 @@ class _FSDPState(_State):
         self.process_group: Optional[dist.ProcessGroup] = None
         self.rank: int = -1
         self.world_size: int = -1
+        self._device_mesh: Optional[DeviceMesh] = None
         self.sharding_strategy = ShardingStrategy.FULL_SHARD
         self._use_orig_params: bool = False
         self.training_state = TrainingState.IDLE
@@ -127,6 +128,10 @@ class _FSDPState(_State):
             nn.Module, Optional[flat_param_file.FlatParamHandle]
         ] = {}
         self.compute_device: Optional[torch.device] = None
+        self._gradient_predivide_factor: int = 0
+        self._gradient_postdivide_factor: int = 0
+        self._comm_hook: Optional[Callable] = None
+        self._comm_hook_state: Optional[Any] = None
         # Abstract device handle for fsdp compute device. For now,
         # the compute device must implement cuda semantics used by fsdp
         self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle()
@@ -134,11 +139,6 @@ class _FSDPState(_State):
         # Save these static lists to avoid the repeated tree traversals
         self._all_fsdp_states: List[_FSDPState] = []
         self._all_handles: List[flat_param_file.FlatParamHandle] = []
-        self._gradient_predivide_factor: int = 0
-        self._gradient_postdivide_factor: int = 0
-        self._comm_hook: Optional[Callable] = None
-        self._comm_hook_state: Optional[Any] = None
-        self._device_mesh: Optional[DeviceMesh] = None


 def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
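
For readers skimming the diff, here is a minimal sketch of how the moved attributes are typically consumed during gradient reduction. The helper name `_reduce_grad_sketch` and the exact call pattern are illustrative assumptions, not FSDP's actual implementation; they only show why the pre-divide factor, post-divide factor, communication hook, and hook state naturally sit together.

    # Minimal sketch (assumed, not FSDP's real code): how the attributes moved
    # by this commit relate to one another during gradient reduction.
    import torch
    import torch.distributed as dist

    def _reduce_grad_sketch(state, grad: torch.Tensor) -> torch.Tensor:
        # `state` stands in for an _FSDPState-like object carrying the moved attrs.
        if state._comm_hook is not None:
            # A registered communication hook takes over the reduction entirely,
            # receiving the user-provided hook state.
            return state._comm_hook(state._comm_hook_state, grad)
        if state._gradient_predivide_factor > 1:
            # Pre-divide to reduce overflow risk before the collective.
            grad.div_(state._gradient_predivide_factor)
        dist.all_reduce(grad, group=state.process_group)
        if state._gradient_postdivide_factor > 1:
            # Post-divide to complete the averaging across ranks.
            grad.div_(state._gradient_postdivide_factor)
        return grad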