diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index a6a8bfe13d7..5556d3ee4ed 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -1666,6 +1666,52 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         out = func(query, key, value, block_mask=block_mask)
         out.sum().backward()
 
+    @supported_platform
+    @common_utils.parametrize("mode", ["eager", "inductor"])
+    @common_utils.parametrize(
+        "permute_order",
+        [
+            (0, 1, 2, 3),  # Default order
+            (1, 0, 2, 3),  # Reverse order
+            (0, 2, 1, 3),  # Mixed order
+            (2, 0, 1, 3),  # Another mixed order
+        ],
+    )
+    @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
+    def test_flex_attention_stride_ordering(self, mode, permute_order, shape):
+        from torch._inductor.ir import get_stride_order
+
+        # Setup
+        make_tensor = functools.partial(
+            torch.randn,
+            shape,
+            device="cuda",
+            dtype=torch.float32,
+            requires_grad=True,
+        )
+
+        # Create and permute tensors
+        query, key, value = make_tensor(), make_tensor(), make_tensor()
+        query = query.permute(permute_order)
+        key = key.permute(permute_order)
+        value = value.permute(permute_order)
+
+        if mode == "inductor":
+            func = torch.compile(flex_attention, backend=mode, fullgraph=True)
+        else:
+            func = flex_attention
+
+        out = func(query, key, value)
+
+        out_stride_order = get_stride_order(out.stride())
+        query_stride_order = get_stride_order(query.stride())
+
+        self.assertEqual(
+            out_stride_order,
+            query_stride_order,
+            f"Stride order mismatch: out {out_stride_order}, query {query_stride_order}",
+        )
+
     @supported_platform
     @common_utils.parametrize("compile", [True, False])
     def test_fully_masked_out_rows_0_check(self, compile: bool):
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 3783b29f117..b28f657b9d5 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import math
-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -23,6 +23,53 @@ from torch.fx.graph_module import GraphModule
 from torch.overrides import TorchFunctionMode
 
 
+# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
+def _construct_strides(
+    sizes: Sequence[int],
+    fill_order: Sequence[int],
+) -> Sequence[int]:
+    """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
+    # Initialize strides
+    assert len(sizes) == len(
+        fill_order
+    ), "Length of sizes must match the length of the fill order"
+    strides = [0] * len(sizes)
+
+    # Start with stride 1 for the innermost dimension
+    current_stride = 1
+
+    # Iterate through the fill order populating strides
+    for dim in fill_order:
+        strides[dim] = current_stride
+        current_stride *= sizes[dim]
+
+    return strides
+
+
+def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch.Tensor:
+    """
+    Create a new tensor with the same data and shape as the input,
+    but with strides permuted based on the input tensor's stride order.
+
+    Args:
+        out (torch.Tensor): The output tensor of attention.
+        query_strides (Tuple[int, ...]): The strides of the input query tensor.
+
+    Returns:
+        torch.Tensor: A new tensor with the same shape and data as the input,
+        but with strides permuted based on the query tensor's stride order.
+    """
+    from torch._inductor.ir import get_stride_order, stride_order2fill_order
+
+    stride_order = get_stride_order(query_strides)
+    fill_order = stride_order2fill_order(stride_order)
+    assert out.storage_offset() == 0, "Only support storage_offset == 0"
+    out_strides = _construct_strides(out.shape, fill_order)
+    new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides)
+    new_out.copy_(out)
+    return new_out
+
+
 class TransformGetItemToIndex(TorchFunctionMode):
     # This is needed since we want to support calling
     # A[q_idx], where q_idx is a scalar tensor in score_mod.
@@ -244,7 +291,7 @@ def sdpa_dense(
         score_mod_other_buffers,
         mask_mod_other_buffers,
     )
-    out = out.contiguous()
+    out = _permute_strides(out, query.stride())
     return out, lse
 
 
@@ -432,7 +479,9 @@ def flex_attention_fake_tensor_mode(
         batch_size, num_heads, seq_len_q, dtype=torch.float32
     )
     out_shape = (batch_size, num_heads, seq_len_q, v_head_dim)
-    return query.new_empty(out_shape), logsumexp
+    out = query.new_empty(out_shape)
+    out = _permute_strides(out, query.stride())
+    return out, logsumexp
 
 
 # ---------------------------- Autograd Implementation ----------------------------
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 0e3d6ae4918..d6dfca28662 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -3,7 +3,7 @@
 
 import logging
 import math
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import sympy
 
@@ -17,9 +17,11 @@ from ..ir import (
     ExternKernel,
     FixedLayout,
     FlexibleLayout,
+    get_stride_order,
     InputBuffer,
     IRNode,
     StorageBox,
+    stride_order2fill_order,
     Subgraph,
     TensorBox,
 )
@@ -29,6 +31,29 @@ from ..select_algorithm import autotune_select_algorithm, realize_inputs, Triton
 
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
+Expr = sympy.Expr
+
+
+def construct_strides(
+    sizes: Sequence[int],
+    fill_order: Sequence[int],
+) -> Sequence[int]:
+    """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
+    # Initialize strides
+    assert len(sizes) == len(
+        fill_order
+    ), "Length of sizes must match the length of the fill order"
+    strides = [0] * len(sizes)
+
+    # Start with stride 1 for the innermost dimension
+    current_stride = 1
+
+    # Iterate through the fill order populating strides
+    for dim in fill_order:
+        strides[dim] = current_stride
+        current_stride *= sizes[dim]
+
+    return strides
 
 
 def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
@@ -761,11 +786,18 @@ def flex_attention(
     # This works because only the last dim differs and we check it is contiguous.
     q_strides = query.get_stride()
     assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"
+
+    # Construct output layout with strides matching the query.
+    out_size = [B, Hq, seq_len_q, v_head_dim]
+    stride_order = get_stride_order(query.get_stride())
+    fill_order = stride_order2fill_order(stride_order)
+    out_strides = construct_strides(out_size, fill_order)
+
     layout = FixedLayout(
         query.get_device(),
         query.get_dtype(),
         [B, Hq, seq_len_q, v_head_dim],
-        query.get_stride(),
+        stride=out_strides,
     )
     # see NOTE:[TritonTemplates with multiple outputs]
     logsumexp_shape = [B, Hq, seq_len_q]