Split batch-num-heads grid dim between y and z (#157745)

For #157018.

This doesn't totally fix the problem, but it should help a lot.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157745
Approved by: https://github.com/Chillee
drisspg, 2025-07-07 16:09:00 -07:00, committed by PyTorch MergeBot
parent 39a8f66d59
commit 987314aa96
2 changed files with 42 additions and 12 deletions
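
Why the split: CUDA caps the y and z grid dimensions at 65535 blocks each, while x allows up to 2^31 - 1, so flattening batch_size * q_heads onto a single non-x dimension overflows for large batch-times-heads products. A minimal before/after sketch of the forward grid in plain Python, using the shape from the new test (BLOCK_M is just an assumed illustrative block size):

    import math

    B, H, S, D = 22720, 3, 64, 32        # shape exercised by the new test
    BLOCK_M = 128                        # assumed illustrative query block size

    # Old launch: batch and heads flattened into the second grid dimension.
    grid_old = (math.ceil(S / BLOCK_M), B * H, 1)   # y = 68160 > 65535
    # New launch: batch on y, heads on z, each well under the cap.
    grid_new = (math.ceil(S / BLOCK_M), B, H)       # y = 22720, z = 3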

View File

@@ -4205,6 +4205,36 @@ class GraphModule(torch.nn.Module):
         # vanilla compiled vs TMA compiled
         torch.testing.assert_close(out_tma_compiled, out_compiled, atol=2e-1, rtol=2e-1)
 
+    @supported_platform
+    @skip_on_cpu
+    def test_large_batch_heads_grid_dimension(self, device):
+        B, H, S, D = 22720, 3, 64, 32
+        make_tensor = functools.partial(
+            torch.randn,
+            (B, H, S, D),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=True,
+        )
+        query, key, value = make_tensor(), make_tensor(), make_tensor()
+
+        flex_compile = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+        out_compiled = flex_compile(query, key, value)
+
+        self.assertEqual(out_compiled.shape, (B, H, S, D))
+
+        grad_output = torch.randn_like(out_compiled)
+        out_compiled.backward(grad_output)
+
+        self.assertIsNotNone(query.grad)
+        self.assertIsNotNone(key.grad)
+        self.assertIsNotNone(value.grad)
+        self.assertEqual(query.grad.shape, query.shape)
+        self.assertEqual(key.grad.shape, key.shape)
+        self.assertEqual(value.grad.shape, value.shape)
+
 
 class TestBlockMask(InductorTestCase):
     def setUp(self):

View File

@@ -98,11 +98,11 @@ def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]):
 @SymbolicGridFn
 def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv):
     """How is this kernel parallelized?
-    We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1)
+    We create a grid of (ceil_div(n_queries, query_block_size), batch_size, num_heads)
     Each block is responsible for iterating over blocks of keys and values calculating
     the final attention output.
     """
-    return (cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)
+    return (cdiv(num_queries, meta["BLOCK_M"]), batch_size, q_heads)
 
 
 def create_placeholder(
@@ -390,8 +390,8 @@ compute_flex_attention = r"""
     MATMUL_PRECISION = Q.dtype.element_ty
 
     q_start = tl.program_id(0)
-    off_zq = tl.program_id(1) // HQ
-    off_hq = tl.program_id(1) % HQ
+    off_zq = tl.program_id(1)
+    off_hq = tl.program_id(2)
 
     # Setting up the TMA descriptors for Q, K, V
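
Put differently, the kernel no longer recovers the batch and head indices by dividing a single flattened id; each now arrives on its own grid axis. A small equivalence sketch in plain Python (the concrete numbers are only illustrative):

    HQ = 3                                   # number of query heads
    old_flat_id = 7                          # old scheme: one flattened batch-head program id
    off_zq_old, off_hq_old = divmod(old_flat_id, HQ)

    # New scheme: the launcher supplies batch and head as separate axes, so the
    # kernel reads program_id(1) and program_id(2) directly.
    off_zq_new, off_hq_new = 2, 1            # e.g. program_id(1) == 2, program_id(2) == 1
    assert (off_zq_old, off_hq_old) == (off_zq_new, off_hq_new)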
@@ -573,8 +573,8 @@ compute_flex_attention = r"""
     l_i = tl.where(l_i == 0.0, 1, l_i)
     acc = acc / l_i[:, None]
 
-    idx_zq = tl.program_id(1) // HQ
-    idx_hq = tl.program_id(1) % HQ
+    idx_zq = tl.program_id(1)
+    idx_hq = tl.program_id(2)
     idx_m = offs_m[:, None]
     idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :]
@@ -583,7 +583,7 @@ compute_flex_attention = r"""
     {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
 
     if OUTPUT_LOGSUMEXP:
-        off_hz = tl.program_id(1)
+        off_hz = off_zq * HQ + off_hq
         l_ptrs = LSE + off_hz * Q_LEN + offs_m
         lse = m_i + tl.math.log2(l_i)
         if IS_DIVISIBLE:
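
The logsumexp buffer keeps its flattened (batch * heads, Q_LEN) layout, so the kernel now rebuilds that flat row index from the two grid axes rather than reading program_id(1) directly. A rough sketch of the indexing with a hypothetical helper:

    def lse_row(off_zq: int, off_hq: int, HQ: int) -> int:
        # LSE rows are indexed by the flattened batch-head id, Q_LEN entries per row.
        return off_zq * HQ + off_hq

    assert lse_row(2, 1, HQ=3) == 7          # matches the old flattened program id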
@@ -1572,6 +1572,7 @@ def flex_attention_backward_grid(
     batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta
 ):
     """How is this kernel parallelized?
+    We create a grid of (ceil_div(n_queries, query_block_size) * heads_ratio + ceil_div(n_kv, kv_block_size), batch_size, kv_heads)
     Currently this is only parallelizing over batch * kv_heads, but we can, and want to
     parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size).
     To do this will either require atomic updates to some grad values or to have a two pass kernel design.
@@ -1581,8 +1582,8 @@ def flex_attention_backward_grid(
     return (
         triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads)
         + triton.cdiv(num_key_value, meta["BLOCK_N1"]),
-        1,
-        batch_size * kv_heads,
+        batch_size,
+        kv_heads,
     )
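
The backward launch gets the same treatment: the batch * kv_heads product moves off the z dimension and is split across y and z. Illustrative before/after, reusing the test's shape with assumed block sizes:

    import math

    B, HQ, HKV, S = 22720, 3, 3, 64          # from the new test; HKV assumed equal to HQ
    BLOCK_M2 = BLOCK_N1 = 64                 # assumed illustrative block sizes

    work = math.ceil(S / BLOCK_M2) * (HQ // HKV) + math.ceil(S / BLOCK_N1)
    grid_old = (work, 1, B * HKV)            # z = 68160, over the 65535 cap
    grid_new = (work, B, HKV)                # y = 22720, z = 3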
@@ -1647,9 +1648,8 @@ flex_attention_backward_template = TritonTemplate(
     NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
     NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
 
-    off_hz = tl.program_id(2)
-    off_zq = off_hz // HKV  # q batch idx
-    off_hkv = off_hz % HKV  # kv head idx
+    off_zq = tl.program_id(1)  # q batch idx
+    off_hkv = tl.program_id(2)  # kv head idx
     off_zkv = off_zq % ZKV  # kv batch idx
 
     SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}