mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DeviceMesh] Remove slicing submesh warning messages and clean up in fsdp params (#166466)
Differential Revision: [D85735294](https://our.internmc.facebook.com/intern/diff/D85735294) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166466 Approved by: https://github.com/fegin
This commit is contained in:
parent
a186aa8d6c
commit
f02708c2be
|
|
@ -764,12 +764,6 @@ else:
|
|||
"""
|
||||
slice_from_root = True
|
||||
if self != self._get_root_mesh():
|
||||
warnings.warn(
|
||||
"You are attempting to slice a submesh from another submesh. While we support this operation, "
|
||||
"it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. "
|
||||
"If not, this may result in some ranks receiving the submesh while others encounter errors.",
|
||||
stacklevel=2,
|
||||
)
|
||||
slice_from_root = False
|
||||
|
||||
# The slice mesh_dim_names should consist either the current device_mesh's mesh_dim_names
|
||||
|
|
|
|||
|
|
@ -832,9 +832,7 @@ class FSDPParam:
|
|||
if mesh.mesh_dim_names is None:
|
||||
raise AssertionError("Expected mesh_dim_names to not be None")
|
||||
shard_dim_name = mesh.mesh_dim_names[-1]
|
||||
|
||||
root_mesh = mesh._get_root_mesh()
|
||||
return root_mesh[shard_dim_name]
|
||||
return mesh[shard_dim_name]
|
||||
|
||||
def _assert_in_states(self, *states: ShardedState) -> None:
|
||||
if self.sharded_state not in states:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user