[DeviceMesh] Clean up the call into mesh_resources to get root mesh (#165787)
We moved the method that retrieves the root mesh onto the `DeviceMesh` class in https://github.com/pytorch/pytorch/pull/164510. This PR further cleans up the remaining call sites so they use that method instead of going through `_mesh_resources`.

Differential Revision: [D85090191](https://our.internmc.facebook.com/intern/diff/D85090191)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165787
Approved by: https://github.com/fegin
parent 303c9cf048
commit 7406d2e665
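Every hunk below applies the same mechanical substitution: instead of importing the module-level `_mesh_resources` registry and calling `get_root_mesh(mesh)`, call sites now ask the mesh itself via the private `DeviceMesh._get_root_mesh()` method (adding a `None` guard where the mesh argument is optional). A minimal sketch of the old and new call patterns, assuming an initialized process group, four devices, and a build that already contains `_get_root_mesh()` from #164510:

```python
# Illustrative sketch only; not part of the diff in this commit.
from torch.distributed.device_mesh import init_device_mesh

# Assumes torchrun (or equivalent) already initialized the default process group
# with world_size == 4 and one accelerator per rank.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh = mesh_2d["dp"]  # 1-D slice carved out of the 2-D root mesh

# Old pattern (what this commit removes):
#     from torch.distributed.device_mesh import _mesh_resources
#     root = _mesh_resources.get_root_mesh(dp_mesh)

# New pattern (private API introduced by pytorch/pytorch#164510):
root = dp_mesh._get_root_mesh()

# A sliced mesh reports the mesh it was carved from, which is the condition
# the "root_mesh != device_mesh" checks in the hunks below rely on.
assert root == mesh_2d and root != dp_mesh
```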
@@ -216,9 +216,7 @@ def replicate(
     module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
     device_mesh = kwargs.get("device_mesh")
     if device_mesh is not None:
-        from torch.distributed.device_mesh import _mesh_resources
-
-        root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+        root_mesh = device_mesh._get_root_mesh()
         # if a root mesh is not the same as device_mesh,
         # meaning the device_mesh is sliced out from the root mesh.
         if root_mesh != device_mesh:
@@ -289,8 +289,8 @@ class FSDPParam:
         if self.is_dtensor:
             self._tp_spec = cast(DTensor, param)._spec
             dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
-            dp_global_mesh = _mesh_resources.get_root_mesh(dp_mesh)
-            tp_global_mesh = _mesh_resources.get_root_mesh(tp_mesh)
+            dp_global_mesh = dp_mesh._get_root_mesh() if dp_mesh is not None else None
+            tp_global_mesh = tp_mesh._get_root_mesh() if tp_mesh is not None else None
             if dp_global_mesh != tp_global_mesh or (
                 dp_global_mesh is None or tp_global_mesh is None
             ):
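The `FSDPParam` hunk above also preserves the existing invariant that, for a `DTensor` parameter, the data-parallel and tensor-parallel meshes must resolve to the same root mesh; otherwise the 2-D (FSDP + TP) setup is rejected. A sketch of a layout that satisfies the check, under the same assumptions as the example above:

```python
# Hypothetical 2-D setup; the "dp"/"tp" names are conventions, not requirements.
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]

# Both slices were carved out of the same parent, so the comparison in
# FSDPParam (dp_global_mesh != tp_global_mesh) does not trip.
assert dp_mesh._get_root_mesh() == tp_mesh._get_root_mesh() == mesh_2d
```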
@@ -13,7 +13,7 @@ import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
 import torch.nn as nn
 from torch.distributed.algorithms._comm_hooks import default_hooks
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.distributed.fsdp._common_utils import (
     _FSDPDeviceHandle,
@@ -513,7 +513,7 @@ def _init_prefetching_state(
 def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
     # TODO: we need to add additional check once we support FSDP + PiPPy.
     # This check is currently sufficient, since we only support FSDP + TP.
-    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+    root_mesh = device_mesh._get_root_mesh() if device_mesh is not None else None
     # if a root mesh is not the same as device_mesh,
     # meaning the device_mesh is sliced out from the root mesh.
     if device_mesh and root_mesh != state._device_mesh:
@@ -16,7 +16,6 @@ from torch.distributed._shard.sharded_tensor import (
     Shard,
     ShardedTensor,
 )
-from torch.distributed.device_mesh import _mesh_resources
 from torch.distributed.fsdp._common_utils import (
     _FSDPState,
     _get_module_fsdp_state_if_fully_sharded_module,
@@ -290,7 +289,7 @@ def _full_pre_state_dict_hook(
     ``nn.Module``.
     """
     if getattr(fsdp_state, "_device_mesh", False):
-        _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
+        fsdp_state._device_mesh._get_root_mesh()
 
     _common_pre_state_dict_hook(module, fsdp_state)
     _common_unshard_pre_state_dict_hook(
@@ -664,7 +663,7 @@ def _sharded_pre_load_state_dict_hook(
         if param.device != fsdp_state._device_mesh.device_type:
             param = param.to(fsdp_state._device_mesh.device_type)
 
-        root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
+        root_mesh = fsdp_state._device_mesh._get_root_mesh()
         local_tensor = _ext_all_gather_dtensor(
             param, root_mesh, fsdp_state._fsdp_extension
         )
@@ -14,7 +14,6 @@ from torch.distributed._shard.sharded_tensor import (
 )
 from torch.distributed._shard.sharding_spec import ShardMetadata
 from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
-from torch.distributed.device_mesh import _mesh_resources
 from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
 from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
 from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
@@ -229,7 +228,7 @@ def _chunk_dtensor(
 
     The local rank will gets its corresponding chunk as the local tensor to create a DTensor.
     """
-    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+    root_mesh = device_mesh._get_root_mesh() if device_mesh is not None else None
     if root_mesh is None:
         raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.")
     if root_mesh.ndim < 2:
@@ -701,9 +701,8 @@ class DistributedDataParallel(Module, Joinable):
                )
            self.device_mesh = device_mesh
            self.process_group = device_mesh.get_group(mesh_dim=0)
-            from torch.distributed.device_mesh import _mesh_resources
 
-            root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+            root_mesh = device_mesh._get_root_mesh()
            # if a root mesh is not the same as device_mesh,
            # meaning the device_mesh is sliced out from the root mesh.
            if root_mesh != device_mesh:
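The last hunk makes the same substitution in `DistributedDataParallel.__init__` for the `device_mesh=` code path. A hedged usage sketch of that path with a plain 1-D mesh (the common case, where the root mesh and the passed mesh coincide):

```python
# Sketch of the DDP path touched above; assumes torchrun with world_size == 2,
# one CUDA device per rank, and torch.cuda.set_device(local_rank) already called.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.nn.parallel import DistributedDataParallel as DDP

mesh_1d = init_device_mesh("cuda", (2,))
model = nn.Linear(8, 8).cuda()

# DDP derives its process group from mesh dim 0; as of this commit it resolves
# the root mesh via device_mesh._get_root_mesh() rather than _mesh_resources.
ddp_model = DDP(model, device_mesh=mesh_1d)
```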