Since we are already share a flattened tensor `_rank_map` across all meshes from a same root mesh, we can just use a flattened list of it to replace the comparison of root_mesh and flattened_mesh_list (because with same _rank_map and layout, the mesh tensor is guaranteed to be the same). This way we can also give back the CPU overhead added in https://github.com/pytorch/pytorch/pull/164510 and further simply the code.
We do have a more ambitious universe-based change here: https://github.com/pytorch/pytorch/pull/165680 but it needs more discussions and would lead to BC breaking. We might eventually merge that PR but probably not now and this is a change which is not BC breaking and will help concatenate and 2D integration with concatenate.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166003
Approved by: https://github.com/Skylion007, https://github.com/fegin
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165556
Approved by: https://github.com/fduwjj
ghstack dependencies: #165554, #165555
The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, which is a public API which contains some "legacy" concepts which we'd love to get rid of, such as an explicit/materialized `mesh` Tensor.
In other languages the solution to this would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it.
This new private constructor basically only takes `_layout`, `_global_rank_permutation`, and `mesh_dim_names`.
With such a constructor we can effectively simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing because it was instantiating many DeviceMeshes in a for loop, which always felt unnecessary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165555
Approved by: https://github.com/fduwjj, https://github.com/fegin
ghstack dependencies: #165554
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165556
Approved by: https://github.com/fduwjj
ghstack dependencies: #165554, #165555
The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, which is a public API which contains some "legacy" concepts which we'd love to get rid of, such as an explicit/materialized `mesh` Tensor.
In other languages the solution to this would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it.
This new private constructor basically only takes `_layout`, `_global_rank_permutation`, and `mesh_dim_names`.
With such a constructor we can effectively simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing because it was instantiating many DeviceMeshes in a for loop, which always felt unnecessary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165555
Approved by: https://github.com/fduwjj, https://github.com/fegin
ghstack dependencies: #165554
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
The `_flatten_mapping` field was defined as a class attribute with a mutable default value {}:
```
_flatten_mapping: dict[str, "DeviceMesh"] = {}
```
This caused all DeviceMesh instances to share the same dictionary object. When multiple test instances tried to create flattened meshes with the same name (like "dp"), they would conflict because they were all using the same shared dictionary, resulting in the error: "Flatten mesh with mesh_dim_name dp has been created before, Please specify another valid mesh_dim_name."
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165521
Approved by: https://github.com/fegin, https://github.com/lw
This is mostly mechanical change which make device mesh members all private and use a public property API instead. This is not a BC breaking change since the new API still guarantee BC.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164954
Approved by: https://github.com/fegin
ghstack dependencies: #164750
We allow passing in PG option via https://github.com/pytorch/pytorch/pull/159371 and we did a clean up of Meta internal usage of `_set_mesh_dim_group_options`, since this a private API, we don't have any bc guarantee, we want to directly remove so that people use the new behavior from now on.
Also since we now allow passing pg in both DeviceMesh constructor and flatten API, so that we also want to get rid of the global pg option override variable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164750
Approved by: https://github.com/lw, https://github.com/fegin
We want to refactor the internal bookkeeping of DeviceMesh so that:
Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible.
Concretely, in this PR, we do the following:
1. Use the `_MeshLayout` to handle all index operations rather use a map to record mesh dims.
2. Removed `flatten_name_to_root_dims`, because now we can directly get layout from a flattened device mesh.
3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`.
4. Use the newly added function `check_overlap` to check layout overlap.
5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The reason is that layout acts as a backend of mesh tensor bookkeeping (indexing indices), it needs to be used as indices for remap back to the mesh tensor for new DeviceMesh generation and backend init. For example, in the case of 2K to 4K, the underlying layout is (2K, 1) but the actual value of the mesh tensor is [2K, 2K+1, ....,]. While flattening, slicing, we need to remap the layout back to the new mesh tensor so it maps the actual device allocation. For example, in the 2K to 4K case, if the shape is (1K, 1K) with dim_names ("dp", "tp"). Then when slicing "tp", the mesh tensor should be (2K, 2K+1, ..., 3K-1) or (3K, 3K+1, ... 4K-1). not the global ranks generated from the layout. (1K, 1).
Verified that loss curve is very close for DeepSeekV3 on torchtitan, note that exact same match is challenging because even if we run the baseline twice, the loss curve does not exactly match.
<img width="1113" height="490" alt="image" src="https://github.com/user-attachments/assets/7877b5a4-337e-4ad8-b878-2378f4f0f38d" />
The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor.
With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout.
This is a continue of https://github.com/pytorch/pytorch/pull/161106 (original one got messed with EasyCLA)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163213
Approved by: https://github.com/lw, https://github.com/fegin
We want to refactor the internal bookkeeping of DeviceMesh so that:
Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible.
Concretely, in this PR, we do the following:
1. Use the `_MeshLayout` to handle all index operations rather use a map to record mesh dims.
2. Removed `flatten_name_to_root_dims`, because now we can directly get layout from a flattened device mesh.
3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`.
4. Use the newly added function `check_overlap` to check layout overlap.
5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The reason is that layout acts as a backend of mesh tensor bookkeeping (indexing indices), it needs to be used as indices for remap back to the mesh tensor for new DeviceMesh generation and backend init. For example, in the case of 2K to 4K, the underlying layout is (2K, 1) but the actual value of the mesh tensor is [2K, 2K+1, ....,]. While flattening, slicing, we need to remap the layout back to the new mesh tensor so it maps the actual device allocation. For example, in the 2K to 4K case, if the shape is (1K, 1K) with dim_names ("dp", "tp"). Then when slicing "tp", the mesh tensor should be (2K, 2K+1, ..., 3K-1) or (3K, 3K+1, ... 4K-1). not the global ranks generated from the layout. (1K, 1).
Verified that loss curve is very close for DeepSeekV3 on torchtitan, note that exact same match is challenging because even if we run the baseline twice, the loss curve does not exactly match.
<img width="1113" height="490" alt="image" src="https://github.com/user-attachments/assets/7877b5a4-337e-4ad8-b878-2378f4f0f38d" />
The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor.
With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout.
This is a continue of https://github.com/pytorch/pytorch/pull/161106 (original one got messed with EasyCLA)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163213
Approved by: https://github.com/lw, https://github.com/fegin
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding:
1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached).
2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in https://github.com/pytorch/pytorch/pull/160709 but it does not fixed the check for the case when we call the `_flatten` twice.
What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why?
1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__).
2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that.
Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161311
Approved by: https://github.com/fegin
Summary: This adds a `_rank` field to DeviceMesh init that allows for instantiating a DeviceMesh without depending on `dist.get_rank()` which requires a global PG to be instantiated.
Test Plan:
```
buck2 test mode/opt -c fbcode.enable_gpu_sections=true //caffe2/test/distributed:device_mesh -- init_backend
```
Rollback Plan:
Differential Revision: D81981777
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162439
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
This PR is greatly simplified now that it stacked on top of a PR that builds with distributed always. We only need to stub functions that may not be defined due to a backend not being enabled.
Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159889
Approved by: https://github.com/wconstab
ghstack dependencies: #160449
Fix the `DeviceMesh._flatten` docstring example of use. Alternative fix would be to replace `mesh_3d["dp", "cp"]` with `mesh_3d["cp", "tp"]`.
(I verified the fix using the `gloo` backend)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162277
Approved by: https://github.com/ezyang
This PR is greatly simplified now that it stacked on top of a PR that builds with distributed always. We only need to stub functions that may not be defined due to a backend not being enabled.
Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159889
Approved by: https://github.com/wconstab
ghstack dependencies: #160449
This PR is greatly simplified now that it stacked on top of a PR that builds with distributed always. We only need to stub functions that may not be defined due to a backend not being enabled.
Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159889
Approved by: https://github.com/wconstab
ghstack dependencies: #160449