From 987314aa96fdf8aa051e3643b26f4209b7fe166d Mon Sep 17 00:00:00 2001
From: drisspg
Date: Mon, 7 Jul 2025 16:09:00 -0700
Subject: [PATCH] Split batch-num-heads grid dim between y and z (#157745)

For #157018: this doesn't totally fix the problem, but it should help a lot.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157745
Approved by: https://github.com/Chillee
---
 test/inductor/test_flex_attention.py     | 30 ++++++++++++++++++++++++
 torch/_inductor/kernel/flex_attention.py | 24 +++++++++----------
 2 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 4d14555800c..e9c6f3b4720 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -4205,6 +4205,36 @@ class GraphModule(torch.nn.Module):
         # vanilla compiled vs TMA compiled
         torch.testing.assert_close(out_tma_compiled, out_compiled, atol=2e-1, rtol=2e-1)
 
+    @supported_platform
+    @skip_on_cpu
+    def test_large_batch_heads_grid_dimension(self, device):
+        B, H, S, D = 22720, 3, 64, 32
+
+        make_tensor = functools.partial(
+            torch.randn,
+            (B, H, S, D),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=True,
+        )
+
+        query, key, value = make_tensor(), make_tensor(), make_tensor()
+
+        flex_compile = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+        out_compiled = flex_compile(query, key, value)
+
+        self.assertEqual(out_compiled.shape, (B, H, S, D))
+
+        grad_output = torch.randn_like(out_compiled)
+        out_compiled.backward(grad_output)
+
+        self.assertIsNotNone(query.grad)
+        self.assertIsNotNone(key.grad)
+        self.assertIsNotNone(value.grad)
+        self.assertEqual(query.grad.shape, query.shape)
+        self.assertEqual(key.grad.shape, key.shape)
+        self.assertEqual(value.grad.shape, value.shape)
+
 
 class TestBlockMask(InductorTestCase):
     def setUp(self):
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 99e869dc8fd..e6f150c1e98 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -98,11 +98,11 @@ def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]):
 @SymbolicGridFn
 def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv):
     """How is this kernel parallelized?
-    We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1)
+    We create a grid of (ceil_div(n_queries, query_block_size), batch_size, num_heads)
     Each block is responsible for iterating over blocks of keys and values calculating
     the final attention output.
""" - return (cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1) + return (cdiv(num_queries, meta["BLOCK_M"]), batch_size, q_heads) def create_placeholder( @@ -390,8 +390,8 @@ compute_flex_attention = r""" MATMUL_PRECISION = Q.dtype.element_ty q_start = tl.program_id(0) - off_zq = tl.program_id(1) // HQ - off_hq = tl.program_id(1) % HQ + off_zq = tl.program_id(1) + off_hq = tl.program_id(2) # Setting up the TMA descriptors for Q, K, V @@ -573,8 +573,8 @@ compute_flex_attention = r""" l_i = tl.where(l_i == 0.0, 1, l_i) acc = acc / l_i[:, None] - idx_zq = tl.program_id(1) // HQ - idx_hq = tl.program_id(1) % HQ + idx_zq = tl.program_id(1) + idx_hq = tl.program_id(2) idx_m = offs_m[:, None] idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :] @@ -583,7 +583,7 @@ compute_flex_attention = r""" {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} if OUTPUT_LOGSUMEXP: - off_hz = tl.program_id(1) + off_hz = off_zq * HQ + off_hq l_ptrs = LSE + off_hz * Q_LEN + offs_m lse = m_i + tl.math.log2(l_i) if IS_DIVISIBLE: @@ -1572,6 +1572,7 @@ def flex_attention_backward_grid( batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta ): """How is this kernel parallelized? + We create a grid of (ceil_div(n_queries, query_block_size) * heads_ratio + ceil_div(n_kv, kv_block_size), batch_size, kv_heads) Currently this is only parallelizing over batch* kv_heads, but we can, and want to parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size). To do this will either require atomic updates to some grad values or to have a two pass kernel design. @@ -1581,8 +1582,8 @@ def flex_attention_backward_grid( return ( triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads) + triton.cdiv(num_key_value, meta["BLOCK_N1"]), - 1, - batch_size * kv_heads, + batch_size, + kv_heads, ) @@ -1647,9 +1648,8 @@ flex_attention_backward_template = TritonTemplate( NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) - off_hz = tl.program_id(2) - off_zq = off_hz // HKV # q batch idx - off_hkv = off_hz % HKV # kv head idx + off_zq = tl.program_id(1) # q batch idx + off_hkv = tl.program_id(2) # kv head idx off_zkv = off_zq % ZKV # kv batch idx SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}