[DeviceMesh] Add error when users try to slice non contiguous flattened dim submesh (#157523)

With https://github.com/pytorch/pytorch/issues/157393, we want to first throw a clearer error for users and then fix it in the long-term

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157523
Approved by: https://github.com/fegin
ghstack dependencies: #157501
This commit is contained in:
fduwjj 2025-07-02 20:47:38 -07:00 committed by PyTorch MergeBot
parent 2b8d3b1b2b
commit 63a96eaeb8

View File

@ -136,6 +136,10 @@ else:
# mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
mesh_dims_remained_idx = list(range(mesh_tensor.ndim))
for idx in slice_dim_idx:
if idx not in mesh_dims_remained_idx:
raise NotImplementedError(
"Currently, this only allows slicing out a contiguous flattened dim."
)
mesh_dims_remained_idx.remove(idx)
# pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx]