mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DeviceMesh] Add error when users try to slice non contiguous flattened dim submesh (#157523)
With https://github.com/pytorch/pytorch/issues/157393, we want to first throw a clearer error for users and then fix it in the long-term Pull Request resolved: https://github.com/pytorch/pytorch/pull/157523 Approved by: https://github.com/fegin ghstack dependencies: #157501
This commit is contained in:
parent
2b8d3b1b2b
commit
63a96eaeb8
|
|
@ -136,6 +136,10 @@ else:
|
|||
# mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
|
||||
mesh_dims_remained_idx = list(range(mesh_tensor.ndim))
|
||||
for idx in slice_dim_idx:
|
||||
if idx not in mesh_dims_remained_idx:
|
||||
raise NotImplementedError(
|
||||
"Currently, this only allows slicing out a contiguous flattened dim."
|
||||
)
|
||||
mesh_dims_remained_idx.remove(idx)
|
||||
|
||||
# pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user