From 63a96eaeb84f5af6e83c7360c09f5540a44d19ca Mon Sep 17 00:00:00 2001 From: fduwjj Date: Wed, 2 Jul 2025 20:47:38 -0700 Subject: [PATCH] [DeviceMesh] Add error when users try to slice non contiguous flattened dim submesh (#157523) As a first step toward addressing https://github.com/pytorch/pytorch/issues/157393, we throw a clearer error for users now; the underlying issue will be fixed in the long term. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157523 Approved by: https://github.com/fegin ghstack dependencies: #157501 --- torch/distributed/device_mesh.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 8606dc20593..dfd07c707a7 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -136,6 +136,10 @@ else: # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now. mesh_dims_remained_idx = list(range(mesh_tensor.ndim)) for idx in slice_dim_idx: + if idx not in mesh_dims_remained_idx: + raise NotImplementedError( + "Currently, this only allows slicing out a contiguous flattened dim." + ) mesh_dims_remained_idx.remove(idx) # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx]