From 51667435f50c025ca3655ff0eeb917b4ef0ffb78 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 31 Oct 2025 18:05:10 +0000 Subject: [PATCH] [FlexFlash] Wire up mask_mod + blockmask to flash impl (#166359) I have some local changes that I need to push to flash first https://github.com/Dao-AILab/flash-attention/pull/1970 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166359 Approved by: https://github.com/v0i0 --- test/inductor/test_flex_flash.py | 171 +++++++++++++++--- .../codegen/cutedsl/cutedsl_kernel.py | 62 +++---- torch/_inductor/kernel/flex/flex_attention.py | 42 +++-- .../kernel/flex/flex_flash_attention.py | 71 ++++---- .../flex/templates/flash_attention.py.jinja | 30 ++- 5 files changed, 260 insertions(+), 116 deletions(-) diff --git a/test/inductor/test_flex_flash.py b/test/inductor/test_flex_flash.py index f75eff65382..b029d880ebf 100644 --- a/test/inductor/test_flex_flash.py +++ b/test/inductor/test_flex_flash.py @@ -122,16 +122,52 @@ def cuda_kernel_profiler(kernel_pattern="flash_attncute"): result["found"] = any(kernel_pattern in name for name in kernel_names) -def flash_vs_triton(q, k, v, score_mod=None, rtol=5e-3, atol=5e-3): +def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2): compiled_fn = torch.compile(flex_attention) + + out_ref_fp32 = flex_attention( + q.to(torch.float32), + k.to(torch.float32), + v.to(torch.float32), + score_mod=score_mod, + block_mask=block_mask, + ).to(q.dtype) + out_flash = compiled_fn( - q, k, v, score_mod=score_mod, kernel_options={"force_flash": True} + q, + k, + v, + score_mod=score_mod, + block_mask=block_mask, + kernel_options={"force_flash": True}, ) - out_no_flash = compiled_fn( - q, k, v, score_mod=score_mod, kernel_options={"force_flash": False} + out_triton = compiled_fn( + q, + k, + v, + score_mod=score_mod, + block_mask=block_mask, + kernel_options={"force_flash": False}, ) - torch.testing.assert_close(out_flash, out_no_flash, rtol=rtol, atol=atol) - return out_flash, out_no_flash + + assert out_flash.shape == out_ref_fp32.shape == out_triton.shape + assert not torch.isnan(out_flash).any() + assert not torch.isnan(out_triton).any() + assert not torch.isnan(out_ref_fp32).any() + assert torch.isfinite(out_flash).all() + assert torch.isfinite(out_triton).all() + assert torch.isfinite(out_ref_fp32).all() + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + + triton_error = (out_triton - out_ref_fp32).abs().max().item() + flash_error = (out_flash - out_ref_fp32).abs().max().item() + + assert flash_error <= rtol * triton_error + fwd_atol, ( + f"Flash error {flash_error:.2e} exceeds {rtol}x Triton error {triton_error:.2e} + {fwd_atol:.2e}" + ) + + return out_flash, out_triton, out_ref_fp32 def name_fn(score_mod): @@ -162,26 +198,6 @@ class TestFlexFlash(InductorTestCase): q, k, v = create_test_tensors(seq_len=seq_len, dtype=dtype, device=device) flash_vs_triton(q, k, v, score_mod=_causal) - @dtypes(torch.float16, torch.bfloat16) - def test_force_flash_error_with_block_mask(self, device, dtype): - """Test that force_flash=True raises error when BlockMask is provided.""" - q, k, v = create_test_tensors(dtype=dtype, device=device) - - # Create a causal block mask - def causal_mask(b, h, q_idx, kv_idx): - return q_idx >= kv_idx - - block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device) - - compiled_fn = torch.compile(flex_attention) - with self.assertRaisesRegex( - RuntimeError, - r"force_flash=True but flash attention cannot be used.*BlockMask.*not supported", - 
): - compiled_fn( - q, k, v, block_mask=block_mask, kernel_options={"force_flash": True} - ) - @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_kernel_called(self, device, dtype): """Test that flash attention kernel is actually called when force_flash=True.""" @@ -257,7 +273,6 @@ class TestFlexFlash(InductorTestCase): """Test that force_flash=True raises error when tensor requires gradients.""" q, k, v = create_test_tensors(dtype=dtype, device=device) - # Create a score mod with requires_grad tensor bias = torch.randn(4, device=device, dtype=dtype, requires_grad=True) def score_mod_with_grad(score, b, h, q_idx, kv_idx): @@ -276,6 +291,108 @@ class TestFlexFlash(InductorTestCase): kernel_options={"force_flash": True}, ) + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_with_block_mask(self, device, dtype): + """Test flash attention with block mask and mask_mod.""" + q, k, v = create_test_tensors(dtype=dtype, device=device) + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device) + flash_vs_triton(q, k, v, block_mask=block_mask) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_block_mask_with_score_mod(self, device, dtype): + """Test flash attention with both block mask and score_mod.""" + q, k, v = create_test_tensors(dtype=dtype, device=device) + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device) + flash_vs_triton(q, k, v, score_mod=_times_two, block_mask=block_mask) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_with_mask_mod_buffer(self, device, dtype): + """Test flash attention with mask_mod that loads from buffer.""" + q, k, v = create_test_tensors( + batch_size=2, num_heads=4, dtype=dtype, device=device + ) + + mask_bias = torch.randn(4, device=device, dtype=dtype) * 0.1 + + def custom_mask(b, h, q_idx, kv_idx): + bias_value = mask_bias[h] + return (q_idx >= kv_idx) | (bias_value > 0) + + block_mask = create_block_mask(custom_mask, 2, 4, 512, 512, device=device) + flash_vs_triton(q, k, v, block_mask=block_mask) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_mask_mod_with_dual_buffers(self, device, dtype): + """Mask modifier should support multiple captured buffers.""" + batch_size, num_heads, seq_len = 2, 4, 512 + q, k, v = create_test_tensors( + batch_size=batch_size, num_heads=num_heads, dtype=dtype, device=device + ) + + head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.2 + batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.2 + + def dual_buffer_mask(b, h, q_idx, kv_idx): + head_term = head_bias[h] + batch_term = batch_bias[b] + causal = q_idx >= kv_idx + bias_cond = (head_term + batch_term).to(torch.float32) > 0 + return causal | bias_cond + + block_mask = create_block_mask( + dual_buffer_mask, batch_size, num_heads, seq_len, seq_len, device=device + ) + flash_vs_triton(q, k, v, block_mask=block_mask) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_score_mod_with_many_buffer_indexing(self, device, dtype): + batch_size, num_heads, seq_len = 2, 4, 512 + q, k, v = create_test_tensors( + batch_size=batch_size, num_heads=num_heads, dtype=dtype, device=device + ) + + head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.15 + query_scale = torch.randn(seq_len, device=device, dtype=dtype) * 0.05 + kv_scale = torch.randn(seq_len, device=device, 
dtype=dtype) * 0.05 + batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.1 + + def complex_score(score, b, h, q_idx, kv_idx): + head_term = head_bias[h] + query_term = query_scale[q_idx] + kv_term = kv_scale[kv_idx] + batch_term = batch_bias[b] + return score + head_term + query_term - kv_term + batch_term + + flash_vs_triton(q, k, v, score_mod=complex_score) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_with_score_and_mask_buffers(self, device, dtype): + """Test flash attention with both score_mod and mask_mod using buffers.""" + q, k, v = create_test_tensors( + batch_size=2, num_heads=4, dtype=dtype, device=device + ) + + score_bias = torch.randn(4, device=device, dtype=dtype) * 0.2 + mask_bias = torch.randn(4, device=device, dtype=dtype) * 0.1 + + def score_with_buffer(score, b, h, q_idx, kv_idx): + return score + score_bias[h] + + def mask_with_buffer(b, h, q_idx, kv_idx): + bias_value = mask_bias[h] + return (q_idx >= kv_idx) | (bias_value > 0) + + block_mask = create_block_mask(mask_with_buffer, 2, 4, 512, 512, device=device) + flash_vs_triton(q, k, v, score_mod=score_with_buffer, block_mask=block_mask) + instantiate_device_type_tests(TestFlexFlash, globals(), only_for="cuda") diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py index ac8ce6f9176..af7e2548ec4 100644 --- a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py +++ b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py @@ -65,6 +65,10 @@ class CuteDSLSubgraphInfo: body: IndentedBuffer template_mask: Optional[str] = None template_out: Optional[str] = None + cse: Optional[CSE[Any]] = None + + def __post_init__(self): + self.only_copy_if_non_none_fields = ("cse",) def to_dict(self): return { @@ -191,10 +195,15 @@ class CuteDSLTemplateKernel(Kernel): body=IndentedBuffer(), template_mask=None, template_out=None, + cse=None, ) subgraph = self.subgraph_bodies[body_name] for key, value in subgraph.to_dict().items(): + if value is None and key in getattr( + subgraph, "only_copy_if_non_none_fields", () + ): + continue setattr(self, key, value) try: @@ -212,15 +221,17 @@ class CuteDSLTemplateKernel(Kernel): setattr(self, key, value) @contextlib.contextmanager - def create_subgraph_body(self, body_name: str): + def create_subgraph_body(self, body_name: str, *, clear_cse: bool = False): """Create a new subgraph body for template processing.""" assert body_name not in self.subgraph_bodies, ( f"Subgraph body '{body_name}' already exists" ) + new_cse = self.cse.clone() if clear_cse else None self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( body=IndentedBuffer(), template_mask=None, template_out=None, + cse=new_cse, ) with self.set_subgraph_body(body_name): yield @@ -294,7 +305,8 @@ class CuteDSLTemplateKernel(Kernel): # Register the hook and return placeholder placeholder = "" - assert placeholder not in self.render_hooks + # TODO: I think double invoking is fine for this specific hook + # assert placeholder not in self.render_hooks self.render_hooks[placeholder] = hook return placeholder @@ -330,7 +342,7 @@ class CuteDSLTemplateKernel(Kernel): while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: num += 1 - with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}", clear_cse=True): subgraph = self._get_subgraph(subgraph_number) modification_handler = ModificationWrapperCuteDSL( self, subgraph_number, fixed_inputs, mask @@ -429,40 +441,20 @@ class 
ModificationWrapperCuteDSL(V.WrapperHandler): # type: ignore[name-defined # val_frag[0] = tensor[index] # result = val_frag.load() - index_frag = self.kernel.cse.generate( - self.kernel.body, - "cute.make_fragment(1, cutlass.Int32)", - dtype=torch.int32, - bounds=ValueRanges.unknown(), + index_frag = self.kernel.cse.newvar(dtype=torch.int32) + self.kernel.body.writeline( + f"{index_frag} = cute.make_fragment(1, cutlass.Int32)" + ) + self.kernel.body.writeline(f"{index_frag}.store({index_str})") + + val_frag = self.kernel.cse.newvar(dtype=var_dtype) + self.kernel.body.writeline( + f"{val_frag} = cute.make_fragment(1, {cute_dtype})" ) - self.kernel.cse.generate( - self.kernel.body, - f"{index_frag}.store({index_str})", - dtype=torch.int32, - bounds=ValueRanges.unknown(), - ) - - val_frag = self.kernel.cse.generate( - self.kernel.body, - f"cute.make_fragment(1, {cute_dtype})", - dtype=var_dtype, - bounds=ValueRanges.unknown(), - ) - - index_var = self.kernel.cse.generate( - self.kernel.body, - f"{index_frag}[0]", - dtype=torch.int32, - bounds=ValueRanges.unknown(), - ) - - self.kernel.cse.generate( - self.kernel.body, - f"{val_frag}[0] = ({var}[{index_var}])", - dtype=var_dtype, - bounds=ValueRanges.unknown(), - ) + index_var = self.kernel.cse.newvar(dtype=torch.int32) + self.kernel.body.writeline(f"{index_var} = {index_frag}[0]") + self.kernel.body.writeline(f"{val_frag}[0] = ({var}[{index_var}])") final_expr = f"{val_frag}.load()" diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index bc148ebc207..7697aad5767 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -193,24 +193,6 @@ def flex_attention( score_mod_other_buffers, mask_mod_other_buffers, ) - if _use_flex_flash_attention( - subgraph, - mask_graph, - kernel_options, - num_score_mod_placeholders=len(placeholder_inps), - ): - return create_flex_flash_attention_kernel( - query, - key, - value, - block_mask, - scale, - kernel_options, - subgraph_buffer, - mask_graph_buffer, - score_mod_other_buffers, - mask_mod_other_buffers, - ) ( query, @@ -240,6 +222,30 @@ def flex_attention( ] ) + if _use_flex_flash_attention( + subgraph, + mask_graph, + kernel_options, + num_score_mod_placeholders=len(placeholder_inps), + ): + return create_flex_flash_attention_kernel( + query, + key, + value, + block_mask, + scale, + kernel_options, + subgraph_buffer, + mask_graph_buffer, + score_mod_other_buffers, + mask_mod_other_buffers, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + mask_graph=mask_graph, + ) + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index 946cff598bc..35c7853d356 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -56,10 +56,8 @@ def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int): return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:]) -def is_trivial_graph( - graph_module: GraphModule, is_score_graph: bool, num_score_mod_placeholders: int -): - """Check if the flex graphs are compatible with Flash Attention.""" +def is_trivial_mask_graph(graph_module: GraphModule) -> bool: + """Mask graph is trivial when it only gates via the default full op.""" graph = graph_module.graph nodes 
= list(graph.nodes) placeholders = [n for n in nodes if n.op == "placeholder"] @@ -67,14 +65,16 @@ def is_trivial_graph( assert len(output) == 1, "Got graph w/ multiple outputs" output_val = output[0].args[0] - if is_score_graph: - if input_buffers_require_grads(graph_module, num_score_mod_placeholders): - return False - return True # party on garth # mask mod graph is empty if we have 4 inputs and full_default output return len(placeholders) == 4 and output_val.target is torch.ops.aten.full.default +@functools.lru_cache(maxsize=1) +def _supports_nontrivial_mask_graphs() -> bool: + """Currently only supported on Hopper (SM90) GPUs.""" + return torch.cuda.get_device_capability()[0] == 9 + + def _can_use_flex_flash_attention( subgraph: Subgraph, mask_graph: Subgraph, num_score_mod_placeholders: int ) -> tuple[bool, str]: @@ -91,32 +91,15 @@ def _can_use_flex_flash_attention( False, "Input buffers require gradients (not supported by flash attention)", ) + mask_trivial = is_trivial_mask_graph(mask_graph.graph_module) - score_trivial = is_trivial_graph( - subgraph.graph_module, - is_score_graph=True, - num_score_mod_placeholders=num_score_mod_placeholders, - ) - mask_trivial = is_trivial_graph( - mask_graph.graph_module, - is_score_graph=False, - num_score_mod_placeholders=num_score_mod_placeholders, - ) + if mask_trivial: + return True, "" - if not score_trivial and not mask_trivial: + if not _supports_nontrivial_mask_graphs(): return ( False, - "Both score and mask graphs are too complex for flash attention (require simple operations only)", - ) - elif not score_trivial: - return ( - False, - "Score modification captured tensors that require gradients (not supported by flash attention)", - ) - elif not mask_trivial: - return ( - False, - "A non None BlockMask was passed to flex attention (not supported by flash attention yet)", + "NYI: Non-trivial mask graphs only supported on Hopper (SM90) for flash attention", ) return True, "" @@ -154,6 +137,11 @@ def create_flex_flash_attention_kernel( mask_graph_buffer: SubgraphResults, score_mod_other_buffers: list[TensorBox], mask_mod_other_buffers: list[TensorBox], + kv_num_blocks: TensorBox | None, + kv_indices: TensorBox | None, + full_kv_num_blocks: TensorBox | None, + full_kv_indices: TensorBox | None, + mask_graph: Subgraph, ) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox | ShapeAsConstantBuffer]: """Create a flex flash attention kernel using CuteDSL template.""" if not ensure_flash_available(): @@ -193,17 +181,34 @@ def create_flex_flash_attention_kernel( stride=[sympy.sympify(s) for s in output.get_stride()], ) + # Used to check if we can skip block sparse impl + mask_graph_is_trivial = is_trivial_mask_graph(mask_graph.graph_module) + + needs_block_mask = not mask_graph_is_trivial + has_full_blocks = full_kv_num_blocks is not None + choices: list[Any] = [] - causal = kernel_options.get("causal", False) assert flash_attention_cutedsl_template is not None + + input_nodes = [query, key, value, lse] + if has_full_blocks: + input_nodes.extend( + [kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices] + ) + + if needs_block_mask and not has_full_blocks: + raise NotImplementedError( + "Flash attention with block mask but without full blocks is not supported yet" + ) + error = flash_attention_cutedsl_template.maybe_append_choice( choices, - input_nodes=[query, key, value, lse], + input_nodes=input_nodes, layout=output_layout, mutated_inputs=[lse], subgraphs=[subgraph_buffer, mask_graph_buffer], SM_SCALE=scale, - CAUSAL=causal, + 
NEEDS_BLOCK_MASK=needs_block_mask, ) if error or not choices: diff --git a/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja b/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja index d4f29bb8470..252e324554f 100644 --- a/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja +++ b/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja @@ -1,6 +1,10 @@ - +{% if NEEDS_BLOCK_MASK %} +{{def_kernel("Q", "K", "V", "LOGSUMEXP", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} +{% else %} {{def_kernel("Q", "K", "V", "LOGSUMEXP")}} +{% endif %} from flash_attn.cute.interface import _flash_attn_fwd + from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch # Transpose tensors for _flash_attn_fwd compatibility (B,H,M,D) -> (B,M,H,D) q_transposed = Q.transpose(1, 2) @@ -26,6 +30,25 @@ output = {{get_output()}} output_transposed = output.transpose(1, 2) + {% if NEEDS_BLOCK_MASK %} + @cute.jit + def mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors): + {{unpack_buffers("aux_tensors", indent_width=8)}} + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + b="b_idx", + h="h_idx", + m="q_idx", + n="kv_idx", + ) | indent_except_first(2) }} + return mask_mod_output + block_sparse_tensors = BlockSparseTensorsTorch(KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX) + {% else %} + block_sparse_tensors = None + mask_mod = None + {% endif %} + # Collect any additional tensor buffers that were added during modifications {% set tensor_buffers = get_tensor_buffers() -%} {% if tensor_buffers -%} @@ -41,10 +64,11 @@ k_transposed, v_transposed, softmax_scale={{SM_SCALE}}, - causal={{CAUSAL}}, return_lse=True, score_mod=score_mod, + mask_mod=mask_mod, out=output_transposed, lse=LOGSUMEXP, + block_sparse_tensors=block_sparse_tensors, aux_tensors=buffers - ) \ No newline at end of file + )
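
A minimal usage sketch of the path this change wires up, mirroring test_flash_attention_with_block_mask above: a causal BlockMask plus kernel_options={"force_flash": True} now lowers to the CuteDSL flash template instead of raising. This assumes a Hopper (SM90) GPU (the lowering gates non-trivial mask graphs on torch.cuda.get_device_capability()[0] == 9) and a flash-attn build that exposes the CuTe interface; shapes are illustrative.

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    def causal_mask(b, h, q_idx, kv_idx):
        # Keep only keys at or before the query position.
        return q_idx >= kv_idx

    device = "cuda"
    # (batch, heads, seq_len, head_dim); shapes chosen for illustration only.
    q, k, v = (
        torch.randn(2, 4, 512, 64, device=device, dtype=torch.float16)
        for _ in range(3)
    )
    block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device)

    compiled_fn = torch.compile(flex_attention)
    # force_flash=True makes the lowering error out if it cannot use the flash kernel.
    out = compiled_fn(
        q, k, v, block_mask=block_mask, kernel_options={"force_flash": True}
    )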
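
For context, the four extra inputs the template now receives (KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX) are the block-sparse tensors flex attention already materializes on the BlockMask; the template wraps them in BlockSparseTensorsTorch for _flash_attn_fwd. A quick sketch of where they live, using the public BlockMask fields (shapes and the causal mask are illustrative):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    bm = create_block_mask(causal_mask, 2, 4, 512, 512, device="cuda")

    # Partially-masked KV blocks: the kernel still evaluates mask_mod on these.
    print(bm.kv_num_blocks.shape, bm.kv_indices.shape)
    # Fully-unmasked KV blocks: mask_mod can be skipped for these entirely.
    if bm.full_kv_num_blocks is not None:
        print(bm.full_kv_num_blocks.shape, bm.full_kv_indices.shape)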