Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
Split batch-num-heads grid dim between y and z (#157745)
For #157018: this doesn't totally fix the problem, but it should help a lot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157745
Approved by: https://github.com/Chillee
This commit is contained in:
parent 39a8f66d59
commit 987314aa96
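Background for the change below: CUDA caps gridDim.y and gridDim.z at 65535, while gridDim.x can go up to 2^31 - 1. Packing batch_size * q_heads into a single trailing grid dimension therefore overflows once the product passes 65535, which is exactly what the shape in the new test provokes. A minimal sketch of the before/after grid math, using an illustrative BLOCK_M of 128 (the real block size comes from autotuning):

import math

# Shape from the new test: batch, heads, sequence length, head dim.
B, H, S, D = 22720, 3, 64, 32
BLOCK_M = 128          # illustrative query tile size, not the autotuned value
CUDA_MAX_YZ = 65535    # hardware limit on gridDim.y and gridDim.z

old_grid = (math.ceil(S / BLOCK_M), B * H, 1)   # previous layout: batch * heads fused into y
new_grid = (math.ceil(S / BLOCK_M), B, H)       # this commit: batch in y, heads in z

print(old_grid)  # (1, 68160, 1) -> 68160 > 65535, the launch fails
print(new_grid)  # (1, 22720, 3) -> both trailing dims fit under the limit
assert old_grid[1] > CUDA_MAX_YZ
assert new_grid[1] <= CUDA_MAX_YZ and new_grid[2] <= CUDA_MAX_YZ

A batch size above 65535 on its own would still overflow y, which is why the commit message says this helps rather than totally fixes #157018.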
@@ -4205,6 +4205,36 @@ class GraphModule(torch.nn.Module):
         # vanilla compiled vs TMA compiled
         torch.testing.assert_close(out_tma_compiled, out_compiled, atol=2e-1, rtol=2e-1)
 
+    @supported_platform
+    @skip_on_cpu
+    def test_large_batch_heads_grid_dimension(self, device):
+        B, H, S, D = 22720, 3, 64, 32
+        make_tensor = functools.partial(
+            torch.randn,
+            (B, H, S, D),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=True,
+        )
+        query, key, value = make_tensor(), make_tensor(), make_tensor()
+
+        flex_compile = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+        out_compiled = flex_compile(query, key, value)
+
+        self.assertEqual(out_compiled.shape, (B, H, S, D))
+
+        grad_output = torch.randn_like(out_compiled)
+        out_compiled.backward(grad_output)
+
+        self.assertIsNotNone(query.grad)
+        self.assertIsNotNone(key.grad)
+        self.assertIsNotNone(value.grad)
+        self.assertEqual(query.grad.shape, query.shape)
+        self.assertEqual(key.grad.shape, key.shape)
+        self.assertEqual(value.grad.shape, value.shape)
+
+
 class TestBlockMask(InductorTestCase):
     def setUp(self):
@@ -98,11 +98,11 @@ def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]):
 @SymbolicGridFn
 def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv):
     """How is this kernel parallelized?
-    We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1)
+    We create a grid of (ceil_div(n_queries, query_block_size), batch_size, num_heads)
     Each block is responsible for iterating over blocks of keys and values calculating
     the final attention output.
     """
-    return (cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)
+    return (cdiv(num_queries, meta["BLOCK_M"]), batch_size, q_heads)
 
 
 def create_placeholder(
@@ -390,8 +390,8 @@ compute_flex_attention = r"""
     MATMUL_PRECISION = Q.dtype.element_ty
 
     q_start = tl.program_id(0)
-    off_zq = tl.program_id(1) // HQ
-    off_hq = tl.program_id(1) % HQ
+    off_zq = tl.program_id(1)
+    off_hq = tl.program_id(2)
 
 
     # Setting up the TMA descriptors for Q, K, V
@@ -573,8 +573,8 @@ compute_flex_attention = r"""
     l_i = tl.where(l_i == 0.0, 1, l_i)
 
     acc = acc / l_i[:, None]
-    idx_zq = tl.program_id(1) // HQ
-    idx_hq = tl.program_id(1) % HQ
+    idx_zq = tl.program_id(1)
+    idx_hq = tl.program_id(2)
     idx_m = offs_m[:, None]
     idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :]
@@ -583,7 +583,7 @@ compute_flex_attention = r"""
     {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
 
     if OUTPUT_LOGSUMEXP:
-        off_hz = tl.program_id(1)
+        off_hz = off_zq * HQ + off_hq
         l_ptrs = LSE + off_hz * Q_LEN + offs_m
         lse = m_i + tl.math.log2(l_i)
         if IS_DIVISIBLE:
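With batch and head on separate grid axes, the forward kernel no longer decodes them from a fused index: program_id(1) is the batch and program_id(2) is the query head, and the fused off_hz is only rebuilt where the logsumexp buffer is still addressed as (batch * heads) rows. A tiny pure-Python check of that equivalence, purely as an illustration of the index math rather than the Triton code itself:

HQ = 3  # number of query heads, matching the new test

for off_zq in range(4):          # stand-in for program_id(1), the batch index
    for off_hq in range(HQ):     # stand-in for program_id(2), the head index
        fused = off_zq * HQ + off_hq     # what the old program_id(1) used to carry
        assert fused // HQ == off_zq     # old decode of the batch index
        assert fused % HQ == off_hq      # old decode of the head index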
@@ -1572,6 +1572,7 @@ def flex_attention_backward_grid(
     batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta
 ):
     """How is this kernel parallelized?
+    We create a grid of (ceil_div(n_queries, query_block_size) * heads_ratio + ceil_div(n_kv, kv_block_size), batch_size, kv_heads)
     Currently this is only parallelizing over batch* kv_heads, but we can, and want to
     parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size).
     To do this will either require atomic updates to some grad values or to have a two pass kernel design.
@@ -1581,8 +1582,8 @@ def flex_attention_backward_grid(
     return (
         triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads)
         + triton.cdiv(num_key_value, meta["BLOCK_N1"]),
-        1,
-        batch_size * kv_heads,
+        batch_size,
+        kv_heads,
     )
 
 
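The backward grid gets the same split: the batch_size * kv_heads product that used to live in z now puts batch_size in y and kv_heads in z, while the x dimension (query blocks times the GQA ratio plus kv blocks) is unchanged. Plugging in the test shape, with illustrative block sizes of 64 since the real ones are autotuned:

import math

B, HQ, HKV, S = 22720, 3, 3, 64   # no GQA in the test, so q_heads == kv_heads
BLOCK_M2 = BLOCK_N1 = 64          # illustrative tile sizes, not the autotuned values

x = math.ceil(S / BLOCK_M2) * (HQ // HKV) + math.ceil(S / BLOCK_N1)

old_grid = (x, 1, B * HKV)   # z = 68160, over the 65535 gridDim.z limit
new_grid = (x, B, HKV)       # y = 22720, z = 3, both within limits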
@@ -1647,9 +1648,8 @@ flex_attention_backward_template = TritonTemplate(
     NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
     NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
 
-    off_hz = tl.program_id(2)
-    off_zq = off_hz // HKV # q batch idx
-    off_hkv = off_hz % HKV # kv head idx
+    off_zq = tl.program_id(1) # q batch idx
+    off_hkv = tl.program_id(2) # kv head idx
     off_zkv = off_zq % ZKV # kv batch idx
 
     SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}