Commit Graph

6 Commits

Author SHA1 Message Date
fduwjj
b0985144b5 [DeviceMesh] Simplifying internal bookkeeping with CuTe layout (#163213)
We want to refactor the internal bookkeeping of DeviceMesh so that:
Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible.

Concretely, in this PR, we do the following:
1. Use the `_MeshLayout` to handle all index operations rather use a map to record mesh dims.
2. Removed `flatten_name_to_root_dims`, because now we can directly get layout from a flattened device mesh.
3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`.
4. Use the newly added function `check_overlap` to check layout overlap.
5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The reason is that layout acts as a backend of mesh tensor bookkeeping (indexing indices), it needs to be used as indices for remap back to the mesh tensor for new DeviceMesh generation and backend init. For example, in the case of 2K to 4K, the underlying layout is (2K, 1) but the actual value of the mesh tensor is [2K, 2K+1, ....,]. While flattening, slicing, we need to remap the layout back to the new mesh tensor so it maps the actual device allocation. For example, in the 2K to 4K case, if the shape is (1K, 1K) with dim_names ("dp", "tp"). Then when slicing "tp", the mesh tensor should be (2K, 2K+1, ..., 3K-1) or (3K, 3K+1, ... 4K-1). not the global ranks generated from the layout. (1K, 1).

Verified that loss curve is very close for DeepSeekV3 on torchtitan, note that exact same match is challenging because even if we run the baseline twice, the loss curve does not exactly match.

<img width="1113" height="490" alt="image" src="https://github.com/user-attachments/assets/7877b5a4-337e-4ad8-b878-2378f4f0f38d" />

The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor.

With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout.

This is a continue of https://github.com/pytorch/pytorch/pull/161106 (original one got messed with EasyCLA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163213
Approved by: https://github.com/lw, https://github.com/fegin
2025-10-02 15:42:03 +00:00
fduwjj
b48a3d0a38 [CuTe] Add layout overlap checking util function in _MeshLayout (#163367)
While refactoring the bookkeeping for DeviceMesh while leveraging CuTe layout, we found that we need to have two more util functions. One is to check whether one layout has overlap inside it or not. For example, (2,2):(2:1) has no overlap while (2,2):(2:2) has overlap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163367
Approved by: https://github.com/fegin
ghstack dependencies: #163212, #163288, #163928, #163930
2025-09-27 00:22:14 +00:00
fduwjj
f1f2e3e4da [DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping (#163212)
DeviceMesh essentially is a way to specify how devices interact with each other or device layout. They are all integers but because they can have various shapes and meshes, it make internal bookkeeping internally way more challenging. Currently our internal bookkeeing inside DeviceMesh is not scalable, so in order to support new functions like `_unflatten`, we need to introduce very complicated logics inside DeviceMesh as pointed out per comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). So thanks to @lw 's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realize that by leveraging CuTe layout algebra([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simply our internal mechanical bookkeeping for and make the abstraction ops way easier on top of it. So to make things go incrementally, we propose couple steps here https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243.

On top of what we have been doing about PyCute we want to continue add methods into the wrapper class so that we can get rank indexes needed for ProcessGroup Creation with a layout object. We also added detailed explanations and comments (thanks to llm) and unit test to show case the code indeed is working as expected.

More PRs are on the way.

This is a continue of https://github.com/pytorch/pytorch/pull/161016 (originally messed with EasyCLA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163212
Approved by: https://github.com/ezyang, https://github.com/fegin, https://github.com/lw
2025-09-26 03:32:19 +00:00
fduwjj
19a4ef0256 [DeviceMesh] Make CuTe layout as mesh layout to be ready for using in DeviceMesh (#162414)
We create a wrapper class named "_MeshLayout" acting as a layout for device mesh so that we can add new methods more specific to DeviceMesh and keep the core logic of CuTe manipulation inside pycute module. This PR create the main body of the code and then next PR will come with actual implementation and unit test for device mesh layout. (Actual implementation can be found in https://github.com/pytorch/pytorch/pull/161016)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162414
Approved by: https://github.com/ezyang, https://github.com/fegin
ghstack dependencies: #162413, #162534
2025-09-15 17:04:41 +00:00
PyTorch MergeBot
b5f4a7dc14 Revert "[DeviceMesh] Make CuTe layout as mesh layout to be ready for using in DeviceMesh (#162414)"
This reverts commit 195ac549d7.

Reverted https://github.com/pytorch/pytorch/pull/162414 on behalf of https://github.com/malfet due to Looks like it broke test_circular_deps on Windows, see d89189f289/1 ([comment](https://github.com/pytorch/pytorch/pull/162414#issuecomment-3286070938))
2025-09-12 16:57:09 +00:00
fduwjj
195ac549d7 [DeviceMesh] Make CuTe layout as mesh layout to be ready for using in DeviceMesh (#162414)
We create a wrapper class acting as a layout for device mesh so that we can add new methods more specific to DeviceMesh and keep the core logic of CuTe manipulation inside pycute module. This PR create the main body of the code and then next PR will come with actual implementation and unit test for device mesh layout. (Actual implementation can be found in https://github.com/pytorch/pytorch/pull/161016)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162414
Approved by: https://github.com/ezyang
ghstack dependencies: #162413, #162534
2025-09-12 07:32:56 +00:00