From 19a4ef025673cba3d682b97661942992bc6ed261 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Fri, 12 Sep 2025 13:49:30 -0700 Subject: [PATCH] [DeviceMesh] Make CuTe layout as mesh layout to be ready for using in DeviceMesh (#162414) We create a wrapper class named "_MeshLayout" acting as a layout for device mesh so that we can add new methods more specific to DeviceMesh and keep the core logic of CuTe manipulation inside pycute module. This PR create the main body of the code and then next PR will come with actual implementation and unit test for device mesh layout. (Actual implementation can be found in https://github.com/pytorch/pytorch/pull/161016) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162414 Approved by: https://github.com/ezyang, https://github.com/fegin ghstack dependencies: #162413, #162534 --- torch/distributed/_mesh_layout.py | 71 +++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 torch/distributed/_mesh_layout.py diff --git a/torch/distributed/_mesh_layout.py b/torch/distributed/_mesh_layout.py new file mode 100644 index 00000000000..86969fccc55 --- /dev/null +++ b/torch/distributed/_mesh_layout.py @@ -0,0 +1,71 @@ +""" +Definition of CuTe inspired Layouts for DeviceMesh internal bookkeeping and functions to manipulate them +""" + +import math +from collections.abc import Iterator +from dataclasses import dataclass + +from torch.distributed._pycute import ( + coalesce, + complement, + composition, + flatten, + IntTuple, + is_int, + is_tuple, + Layout, +) + + +@dataclass(frozen=True, init=True) +class _MeshLayout(Layout): + shape: IntTuple + stride: IntTuple + + def __post_init__(self) -> None: + if not is_tuple(self.shape) and not is_int(self.shape): + raise TypeError(f"shape must be a tuple or int, got {type(self.shape)}") + if not is_tuple(self.stride) and not is_int(self.stride): + raise TypeError(f"stride must be a tuple or int, got {type(self.stride)}") + if ( + is_tuple(self.shape) + and is_tuple(self.stride) + and len(flatten(self.shape)) != len(flatten(self.stride)) + ): + raise ValueError( + f"sizes {len(flatten(self.shape))} and " + f"strides {len(flatten(self.stride))} must have the same length" + ) + + @property + def sizes(self) -> IntTuple: + return self.shape + + @property + def strides(self) -> IntTuple: + return self.stride + + @property + def sizes_and_strides(self) -> Iterator[tuple[int, int]]: + return zip(flatten(self.shape), flatten(self.stride)) + + def numel(self) -> int: + return math.prod(flatten(self.shape)) + + # # operator [] (get-i like tuples) + def __getitem__(self, i: int) -> "_MeshLayout": + layout = super().__getitem__(i) + return _MeshLayout(layout.shape, layout.stride) + + def coalesce(self) -> "_MeshLayout": + layout = coalesce(self) + return _MeshLayout(layout.shape, layout.stride) + + def composition(self, layout: "_MeshLayout") -> "_MeshLayout": + result = composition(self, layout) + return _MeshLayout(result.shape, result.stride) + + def complement(self, world_size: int) -> "_MeshLayout": + layout = complement(self, world_size) + return _MeshLayout(layout.shape, layout.stride)