[DeviceMesh] Clean up the call into mesh_resources to get root mesh (#165787)

We moved the method to get the root mesh into the `DeviceMesh` class in https://github.com/pytorch/pytorch/pull/164510. This PR further cleans up the remaining call sites.
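For context, here is a minimal before/after sketch of the call pattern this PR converges on. The wrapper function names below are illustrative only; `_get_root_mesh()` is the private `DeviceMesh` method that #164510 moved into the class, and `_mesh_resources.get_root_mesh()` is the module-level helper being replaced.

```python
from typing import Optional

from torch.distributed.device_mesh import DeviceMesh


def get_root_mesh_before(device_mesh: DeviceMesh) -> Optional[DeviceMesh]:
    # Old call sites: go through the module-level _mesh_resources helper
    # (may no longer exist in newer PyTorch versions).
    from torch.distributed.device_mesh import _mesh_resources

    return _mesh_resources.get_root_mesh(device_mesh)


def get_root_mesh_after(device_mesh: Optional[DeviceMesh]) -> Optional[DeviceMesh]:
    # New call sites: ask the mesh itself, guarding against a None mesh
    # the way the FSDP call sites in this diff do.
    return device_mesh._get_root_mesh() if device_mesh is not None else None
```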

Differential Revision: [D85090191](https://our.internmc.facebook.com/intern/diff/D85090191)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165787
Approved by: https://github.com/fegin
fduwjj 2025-10-20 09:36:31 -07:00 committed by PyTorch MergeBot
parent 303c9cf048
commit 7406d2e665
6 changed files with 9 additions and 14 deletions

@@ -216,9 +216,7 @@ def replicate(
     module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
     device_mesh = kwargs.get("device_mesh")
     if device_mesh is not None:
-        from torch.distributed.device_mesh import _mesh_resources
-
-        root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+        root_mesh = device_mesh._get_root_mesh()
         # if a root mesh is not the same as device_mesh,
         # meaning the device_mesh is sliced out from the root mesh.
         if root_mesh != device_mesh:

@@ -289,8 +289,8 @@ class FSDPParam:
        if self.is_dtensor:
            self._tp_spec = cast(DTensor, param)._spec
            dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
-            dp_global_mesh = _mesh_resources.get_root_mesh(dp_mesh)
-            tp_global_mesh = _mesh_resources.get_root_mesh(tp_mesh)
+            dp_global_mesh = dp_mesh._get_root_mesh() if dp_mesh is not None else None
+            tp_global_mesh = tp_mesh._get_root_mesh() if tp_mesh is not None else None
            if dp_global_mesh != tp_global_mesh or (
                dp_global_mesh is None or tp_global_mesh is None
            ):

@@ -13,7 +13,7 @@ import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
 import torch.nn as nn
 from torch.distributed.algorithms._comm_hooks import default_hooks
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.distributed.fsdp._common_utils import (
     _FSDPDeviceHandle,
@@ -513,7 +513,7 @@ def _init_prefetching_state(
 def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
     # TODO: we need to add additional check once we support FSDP + PiPPy.
     # This check is currently sufficient, since we only support FSDP + TP.
-    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+    root_mesh = device_mesh._get_root_mesh() if device_mesh is not None else None
     # if a root mesh is not the same as device_mesh,
     # meaning the device_mesh is sliced out from the root mesh.
     if device_mesh and root_mesh != state._device_mesh:

@@ -16,7 +16,6 @@ from torch.distributed._shard.sharded_tensor import (
     Shard,
     ShardedTensor,
 )
-from torch.distributed.device_mesh import _mesh_resources
 from torch.distributed.fsdp._common_utils import (
     _FSDPState,
     _get_module_fsdp_state_if_fully_sharded_module,
@@ -290,7 +289,7 @@ def _full_pre_state_dict_hook(
     ``nn.Module``.
     """
     if getattr(fsdp_state, "_device_mesh", False):
-        _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
+        fsdp_state._device_mesh._get_root_mesh()
     _common_pre_state_dict_hook(module, fsdp_state)
     _common_unshard_pre_state_dict_hook(
@@ -664,7 +663,7 @@ def _sharded_pre_load_state_dict_hook(
             if param.device != fsdp_state._device_mesh.device_type:
                 param = param.to(fsdp_state._device_mesh.device_type)
-            root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
+            root_mesh = fsdp_state._device_mesh._get_root_mesh()
             local_tensor = _ext_all_gather_dtensor(
                 param, root_mesh, fsdp_state._fsdp_extension
             )

@@ -14,7 +14,6 @@ from torch.distributed._shard.sharded_tensor import (
 )
 from torch.distributed._shard.sharding_spec import ShardMetadata
 from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
-from torch.distributed.device_mesh import _mesh_resources
 from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
 from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
 from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
@@ -229,7 +228,7 @@ def _chunk_dtensor(
     The local rank will gets its corresponding chunk as the local tensor to create a DTensor.
     """
-    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+    root_mesh = device_mesh._get_root_mesh() if device_mesh is not None else None
     if root_mesh is None:
         raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.")
     if root_mesh.ndim < 2:

@@ -701,9 +701,8 @@ class DistributedDataParallel(Module, Joinable):
                )
            self.device_mesh = device_mesh
            self.process_group = device_mesh.get_group(mesh_dim=0)
-            from torch.distributed.device_mesh import _mesh_resources
 
-            root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+            root_mesh = device_mesh._get_root_mesh()
            # if a root mesh is not the same as device_mesh,
            # meaning the device_mesh is sliced out from the root mesh.
            if root_mesh != device_mesh: