[DeviceMesh] Remove slicing submesh warning messages and clean up in fsdp params (#166466)

Differential Revision: [D85735294](https://our.internmc.facebook.com/intern/diff/D85735294)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166466
Approved by: https://github.com/fegin
This commit is contained in:
fduwjj 2025-10-28 16:22:52 -07:00 committed by PyTorch MergeBot
parent a186aa8d6c
commit f02708c2be
2 changed files with 1 additions and 9 deletions

View File

@ -764,12 +764,6 @@ else:
"""
slice_from_root = True
if self != self._get_root_mesh():
warnings.warn(
"You are attempting to slice a submesh from another submesh. While we support this operation, "
"it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. "
"If not, this may result in some ranks receiving the submesh while others encounter errors.",
stacklevel=2,
)
slice_from_root = False
# The slice mesh_dim_names should consist either the current device_mesh's mesh_dim_names

View File

@ -832,9 +832,7 @@ class FSDPParam:
if mesh.mesh_dim_names is None:
raise AssertionError("Expected mesh_dim_names to not be None")
shard_dim_name = mesh.mesh_dim_names[-1]
root_mesh = mesh._get_root_mesh()
return root_mesh[shard_dim_name]
return mesh[shard_dim_name]
def _assert_in_states(self, *states: ShardedState) -> None:
if self.sharded_state not in states: