[DeviceMesh] Clean up the call into mesh_resources to get root mesh (#165787)

We moved the method for getting the root mesh into the DeviceMesh class in https://github.com/pytorch/pytorch/pull/164510. This PR further cleans up the remaining call sites so they call that method directly instead of going through the _mesh_resources helper.
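
For context, a minimal sketch of the call-site change, assuming an 8-rank 2x4 mesh with made-up dim names (illustrative only, not part of this PR):

    # Illustrative sketch only: the "cuda" device, the 2x4 shape, and the dim names
    # are assumptions; this is meant to run under torchrun with 8 ranks.
    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
    tp_mesh = mesh["tp"]  # submesh sliced out of the root mesh

    # Before this PR: call sites imported the module-level registry helper.
    # from torch.distributed.device_mesh import _mesh_resources
    # root_mesh = _mesh_resources.get_root_mesh(tp_mesh)

    # After this PR: ask the mesh itself, no extra import needed.
    root_mesh = tp_mesh._get_root_mesh()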

Differential Revision: [D85090191](https://our.internmc.facebook.com/intern/diff/D85090191)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165787
Approved by: https://github.com/fegin
Authored by fduwjj on 2025-10-20 09:36:31 -07:00; committed by PyTorch MergeBot
parent 303c9cf048
commit 7406d2e665
6 changed files with 9 additions and 14 deletions

View File

@@ -216,9 +216,7 @@ def replicate(
     module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
     device_mesh = kwargs.get("device_mesh")
     if device_mesh is not None:
-        from torch.distributed.device_mesh import _mesh_resources
-
-        root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+        root_mesh = device_mesh._get_root_mesh()
         # if a root mesh is not the same as device_mesh,
         # meaning the device_mesh is sliced out from the root mesh.
         if root_mesh != device_mesh:

View File

@@ -289,8 +289,8 @@ class FSDPParam:
         if self.is_dtensor:
             self._tp_spec = cast(DTensor, param)._spec
             dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
-            dp_global_mesh = _mesh_resources.get_root_mesh(dp_mesh)
-            tp_global_mesh = _mesh_resources.get_root_mesh(tp_mesh)
+            dp_global_mesh = dp_mesh._get_root_mesh() if dp_mesh is not None else None
+            tp_global_mesh = tp_mesh._get_root_mesh() if tp_mesh is not None else None
             if dp_global_mesh != tp_global_mesh or (
                 dp_global_mesh is None or tp_global_mesh is None
             ):

View File

@@ -13,7 +13,7 @@ import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
 import torch.nn as nn
 from torch.distributed.algorithms._comm_hooks import default_hooks
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.distributed.fsdp._common_utils import (
     _FSDPDeviceHandle,
@@ -513,7 +513,7 @@ def _init_prefetching_state(
 def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
     # TODO: we need to add additional check once we support FSDP + PiPPy.
     # This check is currently sufficient, since we only support FSDP + TP.
-    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+    root_mesh = device_mesh._get_root_mesh() if device_mesh is not None else None
     # if a root mesh is not the same as device_mesh,
     # meaning the device_mesh is sliced out from the root mesh.
     if device_mesh and root_mesh != state._device_mesh:

View File

@@ -16,7 +16,6 @@ from torch.distributed._shard.sharded_tensor import (
     Shard,
     ShardedTensor,
 )
-from torch.distributed.device_mesh import _mesh_resources
 from torch.distributed.fsdp._common_utils import (
     _FSDPState,
     _get_module_fsdp_state_if_fully_sharded_module,
@@ -290,7 +289,7 @@ def _full_pre_state_dict_hook(
     ``nn.Module``.
     """
     if getattr(fsdp_state, "_device_mesh", False):
-        _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
+        fsdp_state._device_mesh._get_root_mesh()
     _common_pre_state_dict_hook(module, fsdp_state)
     _common_unshard_pre_state_dict_hook(
@@ -664,7 +663,7 @@ def _sharded_pre_load_state_dict_hook(
         if param.device != fsdp_state._device_mesh.device_type:
             param = param.to(fsdp_state._device_mesh.device_type)
-        root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
+        root_mesh = fsdp_state._device_mesh._get_root_mesh()
         local_tensor = _ext_all_gather_dtensor(
             param, root_mesh, fsdp_state._fsdp_extension
         )

View File

@@ -14,7 +14,6 @@ from torch.distributed._shard.sharded_tensor import (
 )
 from torch.distributed._shard.sharding_spec import ShardMetadata
 from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
-from torch.distributed.device_mesh import _mesh_resources
 from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
 from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
 from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
@@ -229,7 +228,7 @@ def _chunk_dtensor(
     The local rank will gets its corresponding chunk as the local tensor to create a DTensor.
     """
-    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+    root_mesh = device_mesh._get_root_mesh() if device_mesh is not None else None
     if root_mesh is None:
         raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.")
     if root_mesh.ndim < 2:

View File

@@ -701,9 +701,8 @@ class DistributedDataParallel(Module, Joinable):
             )
             self.device_mesh = device_mesh
             self.process_group = device_mesh.get_group(mesh_dim=0)
-            from torch.distributed.device_mesh import _mesh_resources
-            root_mesh = _mesh_resources.get_root_mesh(device_mesh)
+            root_mesh = device_mesh._get_root_mesh()
             # if a root mesh is not the same as device_mesh,
             # meaning the device_mesh is sliced out from the root mesh.
             if root_mesh != device_mesh:
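
The root_mesh != device_mesh check that recurs in these call sites distinguishes a submesh sliced out of a larger mesh from a standalone root mesh; a minimal sketch of that distinction, with an assumed 4-rank 2x2 mesh and made-up dim names:

    # Illustrative sketch only: the "cuda" device, the 2x2 shape, and the dim names
    # are assumptions; this is meant to run under torchrun with 4 ranks.
    from torch.distributed.device_mesh import init_device_mesh

    mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    dp_mesh = mesh_2d["dp"]

    root_mesh = dp_mesh._get_root_mesh()
    if root_mesh != dp_mesh:
        # dp_mesh was sliced out of mesh_2d, so its root is the full 2D mesh;
        # the changed call sites above branch on exactly this condition.
        pass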