[CuTe] Change the logic of pycute manipulation ops like coalesce, complement from co-lex to lex (#162690)

PyTorch tensor iteration (.view, contiguous, broadcasting) and NumPy array indexing all follow lexicographic (row-major) order. In Lexicographic (lex) on (i0, i1, …, i{k-1}): the leftmost index(stride is larger) changes fastest and the rightmost index changes slowest and usually last dim is contiguous.

However original pycute is all based on co-lex, after porting their code into pytorch and some cosmetic change, we now make it lex so that we can use it for use cases like device mesh internal bookkeeping and other stuff as well.

Changes included in this PR:
1. We changes all API ported in, included prefix_product(stride inferring and rename it to suffix_product), idx2crd, crd2idx, coalesce, composition, complement, right_inverse and left_inverse to make sure they are working in the lex way.
2. Added more unit test cases for some API mentioned above since existing unit tests do not have full coverage.
3. One bug fix inside composition, which will lead to infinite recursive call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162690
Approved by: https://github.com/ezyang
ghstack dependencies: #162413, #162534, #162414
This commit is contained in:
fduwjj 2025-09-12 14:34:24 -07:00 committed by PyTorch MergeBot
parent 505ee42570
commit 232dd65c15
6 changed files with 237 additions and 33 deletions

View File

@ -47,11 +47,14 @@ _LOGGER = logging.getLogger(__name__)
class TestCoalesce(TestCase):
def helper_test_coalesce(self, layout):
def helper_test_coalesce(self, layout, coalesced_layout=None):
layoutR = coalesce(layout)
_LOGGER.debug(f"{layout} => {layoutR}")
if coalesced_layout:
self.assertEqual(coalesced_layout.shape, layoutR.shape)
self.assertEqual(coalesced_layout.stride, layoutR.stride)
self.assertEqual(size(layoutR), size(layout))
for i in range(size(layout)):
@ -82,11 +85,17 @@ class TestCoalesce(TestCase):
layout = Layout((2, (4, 6)))
self.helper_test_coalesce(layout)
layout = Layout((1, 2), (8, 1))
coalesced_layout = Layout(2, 1)
self.helper_test_coalesce(layout, coalesced_layout)
layout = Layout((2, 4), (4, 1))
self.helper_test_coalesce(layout)
coalesced_layout = Layout(8, 1)
self.helper_test_coalesce(layout, coalesced_layout)
layout = Layout((2, 4, 6), (24, 6, 1))
self.helper_test_coalesce(layout)
coalesced_layout = Layout(48, 1)
self.helper_test_coalesce(layout, coalesced_layout)
layout = Layout((2, 1, 3), (2, 4, 4))
self.helper_test_coalesce(layout)
@ -94,6 +103,10 @@ class TestCoalesce(TestCase):
layout = Layout(((2, 2), (2, 2)), ((1, 4), (8, 32)))
self.helper_test_coalesce(layout)
layout = Layout(((2, 2), (2, 2)), ((32, 8), (4, 1)))
coalesced_layout = Layout((2, 4, 2), (32, 4, 1))
self.helper_test_coalesce(layout, coalesced_layout)
if __name__ == "__main__":
run_tests()

View File

@ -208,11 +208,26 @@ class TestComposition(TestCase):
layoutB = Layout((6), (1))
self.helper_test_composition(layoutA, layoutB)
# Pre-coalesced RHS
layoutA = Layout((8, 6, 4), (7, 4, 1))
layoutB = Layout((6), (1))
self.helper_test_composition(layoutA, layoutB)
# Case when not meet stride divisibility condition
with self.assertRaises(AssertionError):
layoutA = Layout((4, 6, 8, 10), (2, 3, 5, 7))
layoutB = Layout(6, 12)
self.helper_test_composition(layoutA, layoutB)
# Mid-layout truncation
layoutA = Layout((4, 6, 8, 10), (2, 3, 5, 7))
layoutA = Layout((10, 8, 6, 4), (7, 5, 3, 2))
layoutB = Layout(6, 12)
self.helper_test_composition(layoutA, layoutB)
layoutA = Layout((4,), (3,))
layoutB = Layout((6,), (2,))
self.helper_test_composition(layoutA, layoutB)
if __name__ == "__main__":
run_tests()

View File

@ -67,20 +67,159 @@ class TestIntTuple(TestCase):
self.assertEqual(shape_div((6, (3, 4)), 36), (1, (1, 2)))
def test_prefix_product(self):
self.assertEqual(prefix_product(2), 1)
def test_suffix_product(self):
self.assertEqual(suffix_product(2), 1)
self.assertEqual(prefix_product((3, 2)), (1, 3))
self.assertEqual(suffix_product((3, 2)), (2, 1))
self.assertEqual(prefix_product((3, 2, 4)), (1, 3, 6))
self.assertEqual(suffix_product((3, 2, 4)), (8, 4, 1))
self.assertEqual(prefix_product(((2, 3), 4)), ((1, 2), 6))
self.assertEqual(suffix_product(((2, 3), 4)), ((12, 4), 1))
self.assertEqual(
prefix_product(((2, 3), (2, 1, 2), (5, 2, 1))),
((1, 2), (6, 12, 12), (24, 120, 240)),
suffix_product(((2, 3), (2, 1, 2), (5, 2, 1))),
((120, 40), (20, 20, 10), (2, 1, 1)),
)
def test_crd2idx_basic(self):
# Test basic int/int case
self.assertEqual(crd2idx(2, 5, 1), 2)
self.assertEqual(crd2idx(0, 5, 1), 0)
self.assertEqual(crd2idx(4, 5, 1), 4)
# Test with custom stride
self.assertEqual(crd2idx(2, 5, 3), 6)
self.assertEqual(crd2idx(1, 5, 3), 3)
def test_crd2idx_tuple(self):
# Test tuple coordinates with default stride
self.assertEqual(crd2idx((1, 2), (3, 4)), 6) # 1*4 + 2*1 = 6
self.assertEqual(crd2idx((0, 0), (3, 4)), 0)
self.assertEqual(crd2idx((2, 3), (3, 4)), 11) # 2*4 + 3*1 = 11
# Test with custom stride
self.assertEqual(crd2idx((1, 2), (3, 4), (8, 2)), 12) # 1*8 + 2*2 = 12
# Test 3D case
self.assertEqual(crd2idx((1, 0, 2), (2, 3, 4)), 14) # 1*12 + 0*4 + 2*1 = 14
def test_crd2idx_none(self):
# Test None coordinate (should default to 0)
self.assertEqual(crd2idx(None, 5), 0)
self.assertEqual(crd2idx(None, (3, 4)), 0)
def test_crd2idx_int_with_tuple_shape(self):
# Test single integer coordinate with multi-dimensional shape and stride
# When crd is int and shape is tuple, it converts the int to multi-dim coordinate first
self.assertEqual(crd2idx(0, (2, 2), (2, 1)), 0) # 0 -> (0,0) -> 0*2 + 0*1 = 0
self.assertEqual(crd2idx(1, (2, 2), (2, 1)), 1) # 1 -> (0,1) -> 0*2 + 1*1 = 1
self.assertEqual(crd2idx(2, (2, 2), (2, 1)), 2) # 2 -> (1,0) -> 1*2 + 0*1 = 2
self.assertEqual(crd2idx(3, (2, 2), (2, 1)), 3) # 3 -> (1,1) -> 1*2 + 1*1 = 3
# Test with non-trivial strides
self.assertEqual(crd2idx(0, (2, 3), (6, 2)), 0) # 0 -> (0,0) -> 0*6 + 0*2 = 0
self.assertEqual(crd2idx(1, (2, 3), (6, 2)), 2) # 1 -> (0,1) -> 0*6 + 1*2 = 2
self.assertEqual(crd2idx(2, (2, 3), (6, 2)), 4) # 2 -> (0,2) -> 0*6 + 2*2 = 4
self.assertEqual(crd2idx(3, (2, 3), (6, 2)), 6) # 3 -> (1,0) -> 1*6 + 0*2 = 6
self.assertEqual(crd2idx(4, (2, 3), (6, 2)), 8) # 4 -> (1,1) -> 1*6 + 1*2 = 8
self.assertEqual(crd2idx(5, (2, 3), (6, 2)), 10) # 5 -> (1,2) -> 1*6 + 2*2 = 10
# Test with larger strides
self.assertEqual(crd2idx(0, (3, 2), (10, 5)), 0) # 0 -> (0,0) -> 0*10 + 0*5 = 0
self.assertEqual(crd2idx(1, (3, 2), (10, 5)), 5) # 1 -> (0,1) -> 0*10 + 1*5 = 5
self.assertEqual(
crd2idx(2, (3, 2), (10, 5)), 10
) # 2 -> (1,0) -> 1*10 + 0*5 = 10
self.assertEqual(
crd2idx(3, (3, 2), (10, 5)), 15
) # 3 -> (1,1) -> 1*10 + 1*5 = 15
self.assertEqual(
crd2idx(4, (3, 2), (10, 5)), 20
) # 4 -> (2,0) -> 2*10 + 0*5 = 20
self.assertEqual(
crd2idx(5, (3, 2), (10, 5)), 25
) # 5 -> (2,1) -> 2*10 + 1*5 = 25
# Test with 3D shape and various strides
self.assertEqual(
crd2idx(0, (2, 2, 2), (8, 4, 2)), 0
) # 0 -> (0,0,0) -> 0*8 + 0*4 + 0*2 = 0
self.assertEqual(
crd2idx(1, (2, 2, 2), (8, 4, 2)), 2
) # 1 -> (0,0,1) -> 0*8 + 0*4 + 1*2 = 2
self.assertEqual(
crd2idx(2, (2, 2, 2), (8, 4, 2)), 4
) # 2 -> (0,1,0) -> 0*8 + 1*4 + 0*2 = 4
self.assertEqual(
crd2idx(3, (2, 2, 2), (8, 4, 2)), 6
) # 3 -> (0,1,1) -> 0*8 + 1*4 + 1*2 = 6
self.assertEqual(
crd2idx(4, (2, 2, 2), (8, 4, 2)), 8
) # 4 -> (1,0,0) -> 1*8 + 0*4 + 0*2 = 8
self.assertEqual(
crd2idx(7, (2, 2, 2), (8, 4, 2)), 14
) # 7 -> (1,1,1) -> 1*8 + 1*4 + 1*2 = 14
self.assertEqual(
crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), 8
) # 4 -> (1,0,0) -> 1*8 = 8
def test_idx2crd_basic(self):
# Test basic int/int case
self.assertEqual(idx2crd(2, 5, 1), 2)
self.assertEqual(idx2crd(0, 5, 1), 0)
self.assertEqual(idx2crd(4, 5, 1), 4)
# Test with custom stride
self.assertEqual(idx2crd(6, 5, 3), 2) # (6 // 3) % 5 = 2
self.assertEqual(idx2crd(3, 5, 3), 1) # (3 // 3) % 5 = 1
def test_idx2crd_tuple(self):
# Test tuple shape with default stride
self.assertEqual(idx2crd(6, (3, 4)), (1, 2)) # 6 -> (1, 2)
self.assertEqual(idx2crd(0, (3, 4)), (0, 0))
self.assertEqual(idx2crd(11, (3, 4)), (2, 3))
# Test 3D case
self.assertEqual(idx2crd(14, (2, 3, 4)), (1, 0, 2))
def test_crd2idx_idx2crd_roundtrip(self):
# Test that crd2idx and idx2crd are inverse operations
shapes = [
5,
(3, 4),
(2, 3, 4),
(2, 2, 2, 2),
]
for shape in shapes:
size = product(shape)
for idx in range(size):
crd = idx2crd(idx, shape)
recovered_idx = crd2idx(crd, shape)
self.assertEqual(
recovered_idx, idx, f"Failed roundtrip for shape {shape}, idx {idx}"
)
def test_idx2crd_crd2idx_roundtrip(self):
# Test roundtrip starting from coordinates
test_cases = [
(0, 5),
(4, 5),
((0, 0), (3, 4)),
((1, 2), (3, 4)),
((2, 3), (3, 4)),
((0, 0, 0), (2, 3, 4)),
((1, 2, 3), (2, 3, 4)),
]
for crd, shape in test_cases:
idx = crd2idx(crd, shape)
recovered_crd = idx2crd(idx, shape)
self.assertEqual(
recovered_crd, crd, f"Failed roundtrip for crd {crd}, shape {shape}"
)
if __name__ == "__main__":
run_tests()

View File

@ -41,11 +41,11 @@ from .int_tuple import (
IntTuple,
is_int,
is_tuple,
prefix_product,
product,
shape_div,
signum,
slice_,
suffix_product,
tuple_max,
)
from .layout import (

View File

@ -126,18 +126,26 @@ def shape_div(a: IntTuple, b: IntTuple) -> IntTuple:
return (a + b - 1) // b
# Exclusive prefix product with output congruent to input a
def prefix_product(a: IntTuple, init: IntTuple = 1) -> IntTuple:
# Exclusive suffix product with output congruent to input a (lexicographic)
def suffix_product(a: IntTuple, init: IntTuple = 1) -> IntTuple:
# TODO: With all these length asserts, may want to create a zip_strict wrapper.
if is_tuple(a):
if is_tuple(init): # tuple tuple
assert len(a) == len(init)
return tuple(prefix_product(x, i) for x, i in zip(a, init))
return tuple(suffix_product(x, i) for x, i in zip(a, init))
else: # tuple "int"
# r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))]
# Process from right to left for lexicographic ordering
# r = [prefix_product(a[len(a)-1],init)] +
# [prefix_product(a[i],init := init * product(a[i+1])) for i in range(len(a)-1,0)].reverse()
r = []
for v in a:
r.append(prefix_product(v, init))
init = init * product(v)
# Calculate products from right to left, appending to list
for i in range(len(a) - 1, -1, -1):
r.append(suffix_product(a[i], init))
init = init * product(a[i])
# Reverse to get correct lexicographic order
r.reverse()
return tuple(r)
else:
if is_tuple(init): # "int" tuple
@ -150,7 +158,7 @@ def idx2crd(
idx: IntTuple, shape: IntTuple, stride: Optional[IntTuple] = None
) -> IntTuple:
if stride is None:
stride = prefix_product(shape)
stride = suffix_product(shape)
if is_tuple(idx):
if is_tuple(shape) and is_tuple(stride): # tuple tuple tuple
@ -171,7 +179,7 @@ def crd2idx(
crd: Optional[IntTuple], shape: IntTuple, stride: Optional[IntTuple] = None
) -> int:
if stride is None:
stride = prefix_product(shape)
stride = suffix_product(shape)
if is_tuple(crd):
if is_tuple(shape) and is_tuple(stride): # tuple tuple tuple
@ -186,10 +194,11 @@ def crd2idx(
if is_tuple(shape) and is_tuple(stride): # "int" tuple tuple
assert len(shape) == len(stride)
result = 0
for i in range(len(shape) - 1):
# Process from right to left for lexicographic ordering
for i in range(len(shape) - 1, 0, -1):
result += crd2idx(crd % product(shape[i]), shape[i], stride[i])
crd = crd // product(shape[i])
return result + crd2idx(crd, shape[-1], stride[-1])
return result + crd2idx(crd, shape[0], stride[0])
else: # "int" "int" "int"
assert not is_tuple(shape) and not is_tuple(stride)
return crd * stride # all are ints after type checks

View File

@ -31,7 +31,8 @@
#################################################################################################
"""
Definition of CuTe Layouts and functions to manipulate them
Definition of CuTe Layouts and functions to manipulate them which works with the order
of lexicographic instead of co-lexicographic as implemented in the original layout.py
"""
from itertools import chain
@ -45,9 +46,9 @@ from .int_tuple import (
IntTuple,
is_int,
is_tuple,
prefix_product,
product,
slice_,
suffix_product,
)
@ -72,7 +73,7 @@ class Layout(LayoutBase):
def __init__(self, _shape: IntTuple, _stride: Optional[IntTuple] = None) -> None:
self.shape = _shape
if _stride is None:
self.stride = prefix_product(self.shape)
self.stride = suffix_product(self.shape)
else:
self.stride = _stride
@ -168,7 +169,11 @@ def coalesce(layout: Layout, profile: LayoutProfile = None) -> Layout:
result_shape = [1]
result_stride = [0]
for shape, stride in zip(flatten(layout.shape), flatten(layout.stride)):
# Since we now follow lexicographic order, we need to process from right to left.
# And to make implementation more efficient, we append to the end of list and reverse it in the end.
for shape, stride in zip(
reversed(flatten(layout.shape)), reversed(flatten(layout.stride))
):
# skip their shape-1s
if shape == 1:
continue
@ -187,6 +192,8 @@ def coalesce(layout: Layout, profile: LayoutProfile = None) -> Layout:
if len(result_shape) == 1:
return Layout(result_shape[0], result_stride[0])
else:
result_shape.reverse()
result_stride.reverse()
return Layout(tuple(result_shape), tuple(result_stride))
@ -241,14 +248,22 @@ def composition(layoutA: Layout, layoutB: LayoutInput) -> Layout:
rest_shape = layoutB.shape
rest_stride = layoutB.stride
flat_A = coalesce(layoutA)
# when left layout is multi-dimensional sublayout, aka, self = (a,b,...,c):(x,y,...,z), layout = s:d,
# for integral s and d means that we want:
# (1) “remove” the first d elements from left, starting from rightmost. (This will increase the stride.)
# (2) “keep” the first s of those strided elements. (This does not affect the stride.)
# For example, if self = (6,2):(2,1), layout = (3:2)
# Step 1: remove the first 2 elements from self with stride increase, i.e., (6,2):(2,1) -> (6,1):(2,2)
# Step 2: keep the first 3 of those strided elements, i.e., (6,1):(2,2) -> (3,1):(2,2)
# Because we are going lexicographically, we go through left layout from right to left.
for curr_shape, curr_stride in zip(
flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]
reversed(flatten(flat_A.shape)[1:]), reversed(flatten(flat_A.stride)[1:])
):
assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0 # type: ignore[operator]
new_shape = min(max(1, curr_shape // rest_stride), rest_shape) # type: ignore[operator]
if new_shape != 1:
result_shape.append(new_shape)
result_shape.append(new_shape) # Append to end, will reverse later
result_stride.append(rest_stride * curr_stride)
rest_shape = rest_shape // new_shape # type: ignore[operator]
@ -256,9 +271,16 @@ def composition(layoutA: Layout, layoutB: LayoutInput) -> Layout:
-rest_stride // curr_shape # type: ignore[operator]
) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride)
# When left has single-size sublayout or reach the last sublayout, aka, left = a:b, layout = s:d,
# the result is rather trivial: left o layout = a:b o s:d = s:(b*d).
# For example, if self = (6:2), layout = (3:2), the result is (3:(2*2)) = (3:4).
if rest_shape != 1 or len(result_shape) == 0:
result_shape.append(rest_shape)
result_stride.append(rest_stride * flatten(flat_A.stride)[-1])
result_shape.append(rest_shape) # Append to end, will reverse later
result_stride.append(rest_stride * flatten(flat_A.stride)[0])
# Reverse the lists because we build lists in reverse order (append to end), this way it is more efficient.
result_shape.reverse()
result_stride.reverse()
if len(result_shape) == 1:
return Layout(result_shape[0], result_stride[0]) # type: ignore[arg-type]
@ -290,6 +312,10 @@ def complement(layout: LayoutOrIntTuple, max_idx: int = 1) -> Layout:
result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div
result_stride.append(current_idx)
# This is different from original pycute implementation, because we want to follow the lexicographic order here
# where the right-most dimension is the innermost dimension (smallest stride).
result_shape.reverse()
result_stride.reverse()
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
@ -307,7 +333,7 @@ def right_inverse(layout: Optional[LayoutOrIntTuple]) -> Optional[Layout]:
flat_shape = flatten(layout.shape) # type: ignore[union-attr]
flat_stride = flatten(layout.stride) # type: ignore[union-attr]
sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape))) # type: ignore[arg-type]
sorted_DSA = sorted(zip(flat_stride, flat_shape, suffix_product(flat_shape))) # type: ignore[arg-type]
for stride, shape, rstride in sorted_DSA:
if shape == 1:
continue
@ -318,6 +344,8 @@ def right_inverse(layout: Optional[LayoutOrIntTuple]) -> Optional[Layout]:
result_stride.append(rstride)
current_idx = shape * stride
result_shape.reverse()
result_stride.reverse()
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
@ -327,7 +355,7 @@ def left_inverse(layout: Optional[LayoutOrIntTuple]) -> Optional[Layout]:
return None
elif is_int(layout):
return Layout(layout)
return right_inverse(make_layout(layout, complement(layout))) # type: ignore[arg-type]
return right_inverse(make_layout(complement(layout), layout)) # type: ignore[arg-type]
# Split a layout by the composition of B and the "rest"