Commit Graph

9 Commits

Author SHA1 Message Date
Luca Wehrstedt
0d4c2b71e8 [DeviceMesh] Simplify unflatten method (#165556)
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165556
Approved by: https://github.com/fduwjj
ghstack dependencies: #165554, #165555
2025-10-17 17:57:51 +00:00
Luca Wehrstedt
d659bbde62 [DeviceMesh] Introduce private constructor instead of _create_mesh_from_ranks (#165555)
The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, which is a public API which contains some "legacy" concepts which we'd love to get rid of, such as an explicit/materialized `mesh` Tensor.

In other languages the solution to this would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it.

This new private constructor basically only takes `_layout`, `_global_rank_permutation`, and `mesh_dim_names`.

With such a constructor we can effectively simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing because it was instantiating many DeviceMeshes in a for loop, which always felt unnecessary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165555
Approved by: https://github.com/fduwjj, https://github.com/fegin
ghstack dependencies: #165554
2025-10-17 17:57:51 +00:00
Luca Wehrstedt
58879bfafa [DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
2025-10-17 17:57:51 +00:00
PyTorch MergeBot
431c13cf61 Revert "[DeviceMesh] Simplify unflatten method (#165556)"
This reverts commit 86fd4fc23e.

Reverted https://github.com/pytorch/pytorch/pull/165556 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see aba8c43594/1 ([comment](https://github.com/pytorch/pytorch/pull/165554#issuecomment-3412765681))
2025-10-16 20:41:37 +00:00
Luca Wehrstedt
86fd4fc23e [DeviceMesh] Simplify unflatten method (#165556)
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165556
Approved by: https://github.com/fduwjj
ghstack dependencies: #165554, #165555
2025-10-16 18:36:16 +00:00
Yuanyuan Chen
da003d7b95 [3/N] Import Callable from collections.abc in torch/distributed (#164104)
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
This PR is the follow-up of #164054.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164104
Approved by: https://github.com/Skylion007
2025-09-30 00:28:53 +00:00
fduwjj
232dd65c15 [CuTe] Change the logic of pycute manipulation ops like coalesce, complement from co-lex to lex (#162690)
PyTorch tensor iteration (.view, contiguous, broadcasting) and NumPy array indexing all follow lexicographic (row-major) order. In Lexicographic (lex) on (i0, i1, …, i{k-1}): the leftmost index(stride is larger) changes fastest and the rightmost index changes slowest and usually last dim is contiguous.

However original pycute is all based on co-lex, after porting their code into pytorch and some cosmetic change, we now make it lex so that we can use it for use cases like device mesh internal bookkeeping and other stuff as well.

Changes included in this PR:
1. We changes all API ported in, included prefix_product(stride inferring and rename it to suffix_product), idx2crd, crd2idx, coalesce, composition, complement, right_inverse and left_inverse to make sure they are working in the lex way.
2. Added more unit test cases for some API mentioned above since existing unit tests do not have full coverage.
3. One bug fix inside composition, which will lead to infinite recursive call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162690
Approved by: https://github.com/ezyang
ghstack dependencies: #162413, #162534, #162414
2025-09-16 19:53:45 +00:00
fduwjj
561430edcd [CuTe] Add type for CuTe layout via claude (#162534)
This PR mostly is a cosmetic change using Claude to add types for copied PyCute code.
We removed all suppressions of linters and add type checker, type alias and mypy ignore(if needed) so that the pycute code will be checked by linter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162534
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #162413
2025-09-12 04:59:21 +00:00
fduwjj
5dd14f0b65 [CuTe] Copy code from pycute for device mesh bookkeeping (#162413)
We copied the whole module and its unit test into pytorch codebase. (https://github.com/NVIDIA/cutlass/blob/main/python%2Fpycute%2Flayout.py).

We did change the indentation of code from 2 spaces to 4 spaces. And add lint suppressor to make mypy happy.

Also we need to make changes to unit test to include ownership and use `run_tests, TestCase` so that the test gets picked up by CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162413
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2025-09-12 04:28:03 +00:00