mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[CuTe] Change the logic of pycute manipulation ops like coalesce, complement from co-lex to lex (#162690)
PyTorch tensor iteration (.view, contiguous, broadcasting) and NumPy array indexing all follow lexicographic (row-major) order. In Lexicographic (lex) on (i0, i1, …, i{k-1}): the leftmost index(stride is larger) changes fastest and the rightmost index changes slowest and usually last dim is contiguous.
However original pycute is all based on co-lex, after porting their code into pytorch and some cosmetic change, we now make it lex so that we can use it for use cases like device mesh internal bookkeeping and other stuff as well.
Changes included in this PR:
1. We changes all API ported in, included prefix_product(stride inferring and rename it to suffix_product), idx2crd, crd2idx, coalesce, composition, complement, right_inverse and left_inverse to make sure they are working in the lex way.
2. Added more unit test cases for some API mentioned above since existing unit tests do not have full coverage.
3. One bug fix inside composition, which will lead to infinite recursive call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162690
Approved by: https://github.com/ezyang
ghstack dependencies: #162413, #162534, #162414
This commit is contained in:
parent
505ee42570
commit
232dd65c15
|
|
@ -47,11 +47,14 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TestCoalesce(TestCase):
|
||||
def helper_test_coalesce(self, layout):
|
||||
def helper_test_coalesce(self, layout, coalesced_layout=None):
|
||||
layoutR = coalesce(layout)
|
||||
|
||||
_LOGGER.debug(f"{layout} => {layoutR}")
|
||||
|
||||
if coalesced_layout:
|
||||
self.assertEqual(coalesced_layout.shape, layoutR.shape)
|
||||
self.assertEqual(coalesced_layout.stride, layoutR.stride)
|
||||
self.assertEqual(size(layoutR), size(layout))
|
||||
|
||||
for i in range(size(layout)):
|
||||
|
|
@ -82,11 +85,17 @@ class TestCoalesce(TestCase):
|
|||
layout = Layout((2, (4, 6)))
|
||||
self.helper_test_coalesce(layout)
|
||||
|
||||
layout = Layout((1, 2), (8, 1))
|
||||
coalesced_layout = Layout(2, 1)
|
||||
self.helper_test_coalesce(layout, coalesced_layout)
|
||||
|
||||
layout = Layout((2, 4), (4, 1))
|
||||
self.helper_test_coalesce(layout)
|
||||
coalesced_layout = Layout(8, 1)
|
||||
self.helper_test_coalesce(layout, coalesced_layout)
|
||||
|
||||
layout = Layout((2, 4, 6), (24, 6, 1))
|
||||
self.helper_test_coalesce(layout)
|
||||
coalesced_layout = Layout(48, 1)
|
||||
self.helper_test_coalesce(layout, coalesced_layout)
|
||||
|
||||
layout = Layout((2, 1, 3), (2, 4, 4))
|
||||
self.helper_test_coalesce(layout)
|
||||
|
|
@ -94,6 +103,10 @@ class TestCoalesce(TestCase):
|
|||
layout = Layout(((2, 2), (2, 2)), ((1, 4), (8, 32)))
|
||||
self.helper_test_coalesce(layout)
|
||||
|
||||
layout = Layout(((2, 2), (2, 2)), ((32, 8), (4, 1)))
|
||||
coalesced_layout = Layout((2, 4, 2), (32, 4, 1))
|
||||
self.helper_test_coalesce(layout, coalesced_layout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -208,11 +208,26 @@ class TestComposition(TestCase):
|
|||
layoutB = Layout((6), (1))
|
||||
self.helper_test_composition(layoutA, layoutB)
|
||||
|
||||
# Pre-coalesced RHS
|
||||
layoutA = Layout((8, 6, 4), (7, 4, 1))
|
||||
layoutB = Layout((6), (1))
|
||||
self.helper_test_composition(layoutA, layoutB)
|
||||
|
||||
# Case when not meet stride divisibility condition
|
||||
with self.assertRaises(AssertionError):
|
||||
layoutA = Layout((4, 6, 8, 10), (2, 3, 5, 7))
|
||||
layoutB = Layout(6, 12)
|
||||
self.helper_test_composition(layoutA, layoutB)
|
||||
|
||||
# Mid-layout truncation
|
||||
layoutA = Layout((4, 6, 8, 10), (2, 3, 5, 7))
|
||||
layoutA = Layout((10, 8, 6, 4), (7, 5, 3, 2))
|
||||
layoutB = Layout(6, 12)
|
||||
self.helper_test_composition(layoutA, layoutB)
|
||||
|
||||
layoutA = Layout((4,), (3,))
|
||||
layoutB = Layout((6,), (2,))
|
||||
self.helper_test_composition(layoutA, layoutB)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -67,20 +67,159 @@ class TestIntTuple(TestCase):
|
|||
|
||||
self.assertEqual(shape_div((6, (3, 4)), 36), (1, (1, 2)))
|
||||
|
||||
def test_prefix_product(self):
|
||||
self.assertEqual(prefix_product(2), 1)
|
||||
def test_suffix_product(self):
|
||||
self.assertEqual(suffix_product(2), 1)
|
||||
|
||||
self.assertEqual(prefix_product((3, 2)), (1, 3))
|
||||
self.assertEqual(suffix_product((3, 2)), (2, 1))
|
||||
|
||||
self.assertEqual(prefix_product((3, 2, 4)), (1, 3, 6))
|
||||
self.assertEqual(suffix_product((3, 2, 4)), (8, 4, 1))
|
||||
|
||||
self.assertEqual(prefix_product(((2, 3), 4)), ((1, 2), 6))
|
||||
self.assertEqual(suffix_product(((2, 3), 4)), ((12, 4), 1))
|
||||
|
||||
self.assertEqual(
|
||||
prefix_product(((2, 3), (2, 1, 2), (5, 2, 1))),
|
||||
((1, 2), (6, 12, 12), (24, 120, 240)),
|
||||
suffix_product(((2, 3), (2, 1, 2), (5, 2, 1))),
|
||||
((120, 40), (20, 20, 10), (2, 1, 1)),
|
||||
)
|
||||
|
||||
def test_crd2idx_basic(self):
|
||||
# Test basic int/int case
|
||||
self.assertEqual(crd2idx(2, 5, 1), 2)
|
||||
self.assertEqual(crd2idx(0, 5, 1), 0)
|
||||
self.assertEqual(crd2idx(4, 5, 1), 4)
|
||||
|
||||
# Test with custom stride
|
||||
self.assertEqual(crd2idx(2, 5, 3), 6)
|
||||
self.assertEqual(crd2idx(1, 5, 3), 3)
|
||||
|
||||
def test_crd2idx_tuple(self):
|
||||
# Test tuple coordinates with default stride
|
||||
self.assertEqual(crd2idx((1, 2), (3, 4)), 6) # 1*4 + 2*1 = 6
|
||||
self.assertEqual(crd2idx((0, 0), (3, 4)), 0)
|
||||
self.assertEqual(crd2idx((2, 3), (3, 4)), 11) # 2*4 + 3*1 = 11
|
||||
|
||||
# Test with custom stride
|
||||
self.assertEqual(crd2idx((1, 2), (3, 4), (8, 2)), 12) # 1*8 + 2*2 = 12
|
||||
|
||||
# Test 3D case
|
||||
self.assertEqual(crd2idx((1, 0, 2), (2, 3, 4)), 14) # 1*12 + 0*4 + 2*1 = 14
|
||||
|
||||
def test_crd2idx_none(self):
|
||||
# Test None coordinate (should default to 0)
|
||||
self.assertEqual(crd2idx(None, 5), 0)
|
||||
self.assertEqual(crd2idx(None, (3, 4)), 0)
|
||||
|
||||
def test_crd2idx_int_with_tuple_shape(self):
|
||||
# Test single integer coordinate with multi-dimensional shape and stride
|
||||
# When crd is int and shape is tuple, it converts the int to multi-dim coordinate first
|
||||
self.assertEqual(crd2idx(0, (2, 2), (2, 1)), 0) # 0 -> (0,0) -> 0*2 + 0*1 = 0
|
||||
self.assertEqual(crd2idx(1, (2, 2), (2, 1)), 1) # 1 -> (0,1) -> 0*2 + 1*1 = 1
|
||||
self.assertEqual(crd2idx(2, (2, 2), (2, 1)), 2) # 2 -> (1,0) -> 1*2 + 0*1 = 2
|
||||
self.assertEqual(crd2idx(3, (2, 2), (2, 1)), 3) # 3 -> (1,1) -> 1*2 + 1*1 = 3
|
||||
|
||||
# Test with non-trivial strides
|
||||
self.assertEqual(crd2idx(0, (2, 3), (6, 2)), 0) # 0 -> (0,0) -> 0*6 + 0*2 = 0
|
||||
self.assertEqual(crd2idx(1, (2, 3), (6, 2)), 2) # 1 -> (0,1) -> 0*6 + 1*2 = 2
|
||||
self.assertEqual(crd2idx(2, (2, 3), (6, 2)), 4) # 2 -> (0,2) -> 0*6 + 2*2 = 4
|
||||
self.assertEqual(crd2idx(3, (2, 3), (6, 2)), 6) # 3 -> (1,0) -> 1*6 + 0*2 = 6
|
||||
self.assertEqual(crd2idx(4, (2, 3), (6, 2)), 8) # 4 -> (1,1) -> 1*6 + 1*2 = 8
|
||||
self.assertEqual(crd2idx(5, (2, 3), (6, 2)), 10) # 5 -> (1,2) -> 1*6 + 2*2 = 10
|
||||
|
||||
# Test with larger strides
|
||||
self.assertEqual(crd2idx(0, (3, 2), (10, 5)), 0) # 0 -> (0,0) -> 0*10 + 0*5 = 0
|
||||
self.assertEqual(crd2idx(1, (3, 2), (10, 5)), 5) # 1 -> (0,1) -> 0*10 + 1*5 = 5
|
||||
self.assertEqual(
|
||||
crd2idx(2, (3, 2), (10, 5)), 10
|
||||
) # 2 -> (1,0) -> 1*10 + 0*5 = 10
|
||||
self.assertEqual(
|
||||
crd2idx(3, (3, 2), (10, 5)), 15
|
||||
) # 3 -> (1,1) -> 1*10 + 1*5 = 15
|
||||
self.assertEqual(
|
||||
crd2idx(4, (3, 2), (10, 5)), 20
|
||||
) # 4 -> (2,0) -> 2*10 + 0*5 = 20
|
||||
self.assertEqual(
|
||||
crd2idx(5, (3, 2), (10, 5)), 25
|
||||
) # 5 -> (2,1) -> 2*10 + 1*5 = 25
|
||||
|
||||
# Test with 3D shape and various strides
|
||||
self.assertEqual(
|
||||
crd2idx(0, (2, 2, 2), (8, 4, 2)), 0
|
||||
) # 0 -> (0,0,0) -> 0*8 + 0*4 + 0*2 = 0
|
||||
self.assertEqual(
|
||||
crd2idx(1, (2, 2, 2), (8, 4, 2)), 2
|
||||
) # 1 -> (0,0,1) -> 0*8 + 0*4 + 1*2 = 2
|
||||
self.assertEqual(
|
||||
crd2idx(2, (2, 2, 2), (8, 4, 2)), 4
|
||||
) # 2 -> (0,1,0) -> 0*8 + 1*4 + 0*2 = 4
|
||||
self.assertEqual(
|
||||
crd2idx(3, (2, 2, 2), (8, 4, 2)), 6
|
||||
) # 3 -> (0,1,1) -> 0*8 + 1*4 + 1*2 = 6
|
||||
self.assertEqual(
|
||||
crd2idx(4, (2, 2, 2), (8, 4, 2)), 8
|
||||
) # 4 -> (1,0,0) -> 1*8 + 0*4 + 0*2 = 8
|
||||
self.assertEqual(
|
||||
crd2idx(7, (2, 2, 2), (8, 4, 2)), 14
|
||||
) # 7 -> (1,1,1) -> 1*8 + 1*4 + 1*2 = 14
|
||||
|
||||
self.assertEqual(
|
||||
crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), 8
|
||||
) # 4 -> (1,0,0) -> 1*8 = 8
|
||||
|
||||
def test_idx2crd_basic(self):
|
||||
# Test basic int/int case
|
||||
self.assertEqual(idx2crd(2, 5, 1), 2)
|
||||
self.assertEqual(idx2crd(0, 5, 1), 0)
|
||||
self.assertEqual(idx2crd(4, 5, 1), 4)
|
||||
|
||||
# Test with custom stride
|
||||
self.assertEqual(idx2crd(6, 5, 3), 2) # (6 // 3) % 5 = 2
|
||||
self.assertEqual(idx2crd(3, 5, 3), 1) # (3 // 3) % 5 = 1
|
||||
|
||||
def test_idx2crd_tuple(self):
|
||||
# Test tuple shape with default stride
|
||||
self.assertEqual(idx2crd(6, (3, 4)), (1, 2)) # 6 -> (1, 2)
|
||||
self.assertEqual(idx2crd(0, (3, 4)), (0, 0))
|
||||
self.assertEqual(idx2crd(11, (3, 4)), (2, 3))
|
||||
|
||||
# Test 3D case
|
||||
self.assertEqual(idx2crd(14, (2, 3, 4)), (1, 0, 2))
|
||||
|
||||
def test_crd2idx_idx2crd_roundtrip(self):
|
||||
# Test that crd2idx and idx2crd are inverse operations
|
||||
shapes = [
|
||||
5,
|
||||
(3, 4),
|
||||
(2, 3, 4),
|
||||
(2, 2, 2, 2),
|
||||
]
|
||||
|
||||
for shape in shapes:
|
||||
size = product(shape)
|
||||
for idx in range(size):
|
||||
crd = idx2crd(idx, shape)
|
||||
recovered_idx = crd2idx(crd, shape)
|
||||
self.assertEqual(
|
||||
recovered_idx, idx, f"Failed roundtrip for shape {shape}, idx {idx}"
|
||||
)
|
||||
|
||||
def test_idx2crd_crd2idx_roundtrip(self):
|
||||
# Test roundtrip starting from coordinates
|
||||
test_cases = [
|
||||
(0, 5),
|
||||
(4, 5),
|
||||
((0, 0), (3, 4)),
|
||||
((1, 2), (3, 4)),
|
||||
((2, 3), (3, 4)),
|
||||
((0, 0, 0), (2, 3, 4)),
|
||||
((1, 2, 3), (2, 3, 4)),
|
||||
]
|
||||
|
||||
for crd, shape in test_cases:
|
||||
idx = crd2idx(crd, shape)
|
||||
recovered_crd = idx2crd(idx, shape)
|
||||
self.assertEqual(
|
||||
recovered_crd, crd, f"Failed roundtrip for crd {crd}, shape {shape}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -41,11 +41,11 @@ from .int_tuple import (
|
|||
IntTuple,
|
||||
is_int,
|
||||
is_tuple,
|
||||
prefix_product,
|
||||
product,
|
||||
shape_div,
|
||||
signum,
|
||||
slice_,
|
||||
suffix_product,
|
||||
tuple_max,
|
||||
)
|
||||
from .layout import (
|
||||
|
|
|
|||
|
|
@ -126,18 +126,26 @@ def shape_div(a: IntTuple, b: IntTuple) -> IntTuple:
|
|||
return (a + b - 1) // b
|
||||
|
||||
|
||||
# Exclusive prefix product with output congruent to input a
|
||||
def prefix_product(a: IntTuple, init: IntTuple = 1) -> IntTuple:
|
||||
# Exclusive suffix product with output congruent to input a (lexicographic)
|
||||
def suffix_product(a: IntTuple, init: IntTuple = 1) -> IntTuple:
|
||||
# TODO: With all these length asserts, may want to create a zip_strict wrapper.
|
||||
if is_tuple(a):
|
||||
if is_tuple(init): # tuple tuple
|
||||
assert len(a) == len(init)
|
||||
return tuple(prefix_product(x, i) for x, i in zip(a, init))
|
||||
return tuple(suffix_product(x, i) for x, i in zip(a, init))
|
||||
else: # tuple "int"
|
||||
# r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))]
|
||||
# Process from right to left for lexicographic ordering
|
||||
# r = [prefix_product(a[len(a)-1],init)] +
|
||||
# [prefix_product(a[i],init := init * product(a[i+1])) for i in range(len(a)-1,0)].reverse()
|
||||
r = []
|
||||
for v in a:
|
||||
r.append(prefix_product(v, init))
|
||||
init = init * product(v)
|
||||
|
||||
# Calculate products from right to left, appending to list
|
||||
for i in range(len(a) - 1, -1, -1):
|
||||
r.append(suffix_product(a[i], init))
|
||||
init = init * product(a[i])
|
||||
|
||||
# Reverse to get correct lexicographic order
|
||||
r.reverse()
|
||||
return tuple(r)
|
||||
else:
|
||||
if is_tuple(init): # "int" tuple
|
||||
|
|
@ -150,7 +158,7 @@ def idx2crd(
|
|||
idx: IntTuple, shape: IntTuple, stride: Optional[IntTuple] = None
|
||||
) -> IntTuple:
|
||||
if stride is None:
|
||||
stride = prefix_product(shape)
|
||||
stride = suffix_product(shape)
|
||||
|
||||
if is_tuple(idx):
|
||||
if is_tuple(shape) and is_tuple(stride): # tuple tuple tuple
|
||||
|
|
@ -171,7 +179,7 @@ def crd2idx(
|
|||
crd: Optional[IntTuple], shape: IntTuple, stride: Optional[IntTuple] = None
|
||||
) -> int:
|
||||
if stride is None:
|
||||
stride = prefix_product(shape)
|
||||
stride = suffix_product(shape)
|
||||
|
||||
if is_tuple(crd):
|
||||
if is_tuple(shape) and is_tuple(stride): # tuple tuple tuple
|
||||
|
|
@ -186,10 +194,11 @@ def crd2idx(
|
|||
if is_tuple(shape) and is_tuple(stride): # "int" tuple tuple
|
||||
assert len(shape) == len(stride)
|
||||
result = 0
|
||||
for i in range(len(shape) - 1):
|
||||
# Process from right to left for lexicographic ordering
|
||||
for i in range(len(shape) - 1, 0, -1):
|
||||
result += crd2idx(crd % product(shape[i]), shape[i], stride[i])
|
||||
crd = crd // product(shape[i])
|
||||
return result + crd2idx(crd, shape[-1], stride[-1])
|
||||
return result + crd2idx(crd, shape[0], stride[0])
|
||||
else: # "int" "int" "int"
|
||||
assert not is_tuple(shape) and not is_tuple(stride)
|
||||
return crd * stride # all are ints after type checks
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@
|
|||
#################################################################################################
|
||||
|
||||
"""
|
||||
Definition of CuTe Layouts and functions to manipulate them
|
||||
Definition of CuTe Layouts and functions to manipulate them which works with the order
|
||||
of lexicographic instead of co-lexicographic as implemented in the original layout.py
|
||||
"""
|
||||
|
||||
from itertools import chain
|
||||
|
|
@ -45,9 +46,9 @@ from .int_tuple import (
|
|||
IntTuple,
|
||||
is_int,
|
||||
is_tuple,
|
||||
prefix_product,
|
||||
product,
|
||||
slice_,
|
||||
suffix_product,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -72,7 +73,7 @@ class Layout(LayoutBase):
|
|||
def __init__(self, _shape: IntTuple, _stride: Optional[IntTuple] = None) -> None:
|
||||
self.shape = _shape
|
||||
if _stride is None:
|
||||
self.stride = prefix_product(self.shape)
|
||||
self.stride = suffix_product(self.shape)
|
||||
else:
|
||||
self.stride = _stride
|
||||
|
||||
|
|
@ -168,7 +169,11 @@ def coalesce(layout: Layout, profile: LayoutProfile = None) -> Layout:
|
|||
|
||||
result_shape = [1]
|
||||
result_stride = [0]
|
||||
for shape, stride in zip(flatten(layout.shape), flatten(layout.stride)):
|
||||
# Since we now follow lexicographic order, we need to process from right to left.
|
||||
# And to make implementation more efficient, we append to the end of list and reverse it in the end.
|
||||
for shape, stride in zip(
|
||||
reversed(flatten(layout.shape)), reversed(flatten(layout.stride))
|
||||
):
|
||||
# skip their shape-1s
|
||||
if shape == 1:
|
||||
continue
|
||||
|
|
@ -187,6 +192,8 @@ def coalesce(layout: Layout, profile: LayoutProfile = None) -> Layout:
|
|||
if len(result_shape) == 1:
|
||||
return Layout(result_shape[0], result_stride[0])
|
||||
else:
|
||||
result_shape.reverse()
|
||||
result_stride.reverse()
|
||||
return Layout(tuple(result_shape), tuple(result_stride))
|
||||
|
||||
|
||||
|
|
@ -241,14 +248,22 @@ def composition(layoutA: Layout, layoutB: LayoutInput) -> Layout:
|
|||
rest_shape = layoutB.shape
|
||||
rest_stride = layoutB.stride
|
||||
flat_A = coalesce(layoutA)
|
||||
# when left layout is multi-dimensional sublayout, aka, self = (a,b,...,c):(x,y,...,z), layout = s:d,
|
||||
# for integral s and d means that we want:
|
||||
# (1) “remove” the first d elements from left, starting from rightmost. (This will increase the stride.)
|
||||
# (2) “keep” the first s of those strided elements. (This does not affect the stride.)
|
||||
# For example, if self = (6,2):(2,1), layout = (3:2)
|
||||
# Step 1: remove the first 2 elements from self with stride increase, i.e., (6,2):(2,1) -> (6,1):(2,2)
|
||||
# Step 2: keep the first 3 of those strided elements, i.e., (6,1):(2,2) -> (3,1):(2,2)
|
||||
# Because we are going lexicographically, we go through left layout from right to left.
|
||||
for curr_shape, curr_stride in zip(
|
||||
flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]
|
||||
reversed(flatten(flat_A.shape)[1:]), reversed(flatten(flat_A.stride)[1:])
|
||||
):
|
||||
assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0 # type: ignore[operator]
|
||||
new_shape = min(max(1, curr_shape // rest_stride), rest_shape) # type: ignore[operator]
|
||||
|
||||
if new_shape != 1:
|
||||
result_shape.append(new_shape)
|
||||
result_shape.append(new_shape) # Append to end, will reverse later
|
||||
result_stride.append(rest_stride * curr_stride)
|
||||
|
||||
rest_shape = rest_shape // new_shape # type: ignore[operator]
|
||||
|
|
@ -256,9 +271,16 @@ def composition(layoutA: Layout, layoutB: LayoutInput) -> Layout:
|
|||
-rest_stride // curr_shape # type: ignore[operator]
|
||||
) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride)
|
||||
|
||||
# When left has single-size sublayout or reach the last sublayout, aka, left = a:b, layout = s:d,
|
||||
# the result is rather trivial: left o layout = a:b o s:d = s:(b*d).
|
||||
# For example, if self = (6:2), layout = (3:2), the result is (3:(2*2)) = (3:4).
|
||||
if rest_shape != 1 or len(result_shape) == 0:
|
||||
result_shape.append(rest_shape)
|
||||
result_stride.append(rest_stride * flatten(flat_A.stride)[-1])
|
||||
result_shape.append(rest_shape) # Append to end, will reverse later
|
||||
result_stride.append(rest_stride * flatten(flat_A.stride)[0])
|
||||
|
||||
# Reverse the lists because we build lists in reverse order (append to end), this way it is more efficient.
|
||||
result_shape.reverse()
|
||||
result_stride.reverse()
|
||||
|
||||
if len(result_shape) == 1:
|
||||
return Layout(result_shape[0], result_stride[0]) # type: ignore[arg-type]
|
||||
|
|
@ -290,6 +312,10 @@ def complement(layout: LayoutOrIntTuple, max_idx: int = 1) -> Layout:
|
|||
|
||||
result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div
|
||||
result_stride.append(current_idx)
|
||||
# This is different from original pycute implementation, because we want to follow the lexicographic order here
|
||||
# where the right-most dimension is the innermost dimension (smallest stride).
|
||||
result_shape.reverse()
|
||||
result_stride.reverse()
|
||||
|
||||
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
|
||||
|
||||
|
|
@ -307,7 +333,7 @@ def right_inverse(layout: Optional[LayoutOrIntTuple]) -> Optional[Layout]:
|
|||
|
||||
flat_shape = flatten(layout.shape) # type: ignore[union-attr]
|
||||
flat_stride = flatten(layout.stride) # type: ignore[union-attr]
|
||||
sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape))) # type: ignore[arg-type]
|
||||
sorted_DSA = sorted(zip(flat_stride, flat_shape, suffix_product(flat_shape))) # type: ignore[arg-type]
|
||||
for stride, shape, rstride in sorted_DSA:
|
||||
if shape == 1:
|
||||
continue
|
||||
|
|
@ -318,6 +344,8 @@ def right_inverse(layout: Optional[LayoutOrIntTuple]) -> Optional[Layout]:
|
|||
result_stride.append(rstride)
|
||||
current_idx = shape * stride
|
||||
|
||||
result_shape.reverse()
|
||||
result_stride.reverse()
|
||||
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
|
||||
|
||||
|
||||
|
|
@ -327,7 +355,7 @@ def left_inverse(layout: Optional[LayoutOrIntTuple]) -> Optional[Layout]:
|
|||
return None
|
||||
elif is_int(layout):
|
||||
return Layout(layout)
|
||||
return right_inverse(make_layout(layout, complement(layout))) # type: ignore[arg-type]
|
||||
return right_inverse(make_layout(complement(layout), layout)) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# Split a layout by the composition of B and the "rest"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user