diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py
index 8606dc20593..dfd07c707a7 100644
--- a/torch/distributed/device_mesh.py
+++ b/torch/distributed/device_mesh.py
@@ -136,6 +136,10 @@ else:
             # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
             mesh_dims_remained_idx = list(range(mesh_tensor.ndim))
             for idx in slice_dim_idx:
+                if idx not in mesh_dims_remained_idx:
+                    raise NotImplementedError(
+                        "Currently, this only allows slicing out a contiguous flattened dim."
+                    )
                 mesh_dims_remained_idx.remove(idx)
             # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx]
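
For reviewers, a standalone sketch of the index bookkeeping this hunk guards. This is not the real create_sub_mesh code; _drop_sliced_dims is a hypothetical helper used only for illustration. Without the new check, asking to slice out a dim index that is no longer in mesh_dims_remained_idx surfaces as a bare ValueError from list.remove; with it, the failure is reported as NotImplementedError with the message above.

# Hypothetical, simplified mirror of the patched loop; not part of
# torch.distributed.device_mesh.
def _drop_sliced_dims(mesh_ndim, slice_dim_idx):
    mesh_dims_remained_idx = list(range(mesh_ndim))
    for idx in slice_dim_idx:
        if idx not in mesh_dims_remained_idx:
            # Duplicate or already-removed index: the requested slice does not map
            # onto a contiguous flattened dim, so fail loudly instead of leaking
            # a ValueError from list.remove().
            raise NotImplementedError(
                "Currently, this only allows slicing out a contiguous flattened dim."
            )
        mesh_dims_remained_idx.remove(idx)
    return mesh_dims_remained_idx

print(_drop_sliced_dims(3, [0, 2]))  # -> [1], a valid slice
try:
    _drop_sliced_dims(3, [1, 1])  # same index twice -> not a contiguous slice
except NotImplementedError as e:
    print(f"raised as expected: {e}")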