Update FlexAttention with masking semantic (#133373)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133373 Approved by: https://github.com/yanboliang
This commit is contained in:
parent  e7929809f3
commit  629bd6f718
@@ -17,6 +17,7 @@ from torch.nn.attention.flex_attention import (
     _create_empty_block_mask,
     _DEFAULT_SPARSE_BLOCK_SIZE,
     _identity,
+    _score_mod_signature,
     and_masks,
     BlockMask,
     create_block_mask,
@@ -212,8 +213,7 @@ class TestFlexAttention(InductorTestCase):
     ):
         compiled_error = (golden_out - compiled_out).abs().mean()
         ref_error = (golden_out - ref_out).abs().mean()
-        # TODO: Make this check stricter after updating eager SDPA masked_softmax semantics
-        if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
+        if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
             self.assertTrue(False, "Output/Grad with NaN")
         if compiled_error > ref_error * fudge_factor:
             name = tensor_name if tensor_name is not None else ""
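The check above measures both the compiled kernel and the eager reference against a float64 "golden" run and only fails if the compiled error exceeds the reference error by a fudge factor; with the new masking semantics, a NaN in either path is now an immediate failure. A minimal sketch of that comparison pattern outside the test class (the helper name check_close_enough and the default fudge factor are illustrative, not part of the PR):

import torch

def check_close_enough(golden_out, ref_out, compiled_out, fudge_factor=1.1):
    # Error of the compiled kernel and of the eager reference, both measured
    # against a float64 "golden" run on the same inputs.
    compiled_error = (golden_out - compiled_out).abs().mean()
    ref_error = (golden_out - ref_out).abs().mean()
    # With the new masking semantics, NaNs in either path are a hard failure.
    if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
        raise AssertionError("Output/Grad with NaN")
    # The compiled path may be at most fudge_factor times less accurate than
    # eager, since both accumulate in lower precision than the golden run.
    if compiled_error > ref_error * fudge_factor:
        raise AssertionError(
            f"compiled error {compiled_error} exceeds {fudge_factor}x ref error {ref_error}"
        )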
@@ -263,7 +263,7 @@ class TestFlexAttention(InductorTestCase):
     def run_test(
         self,
-        score_mod: Callable,
+        score_mod: _score_mod_signature,
         dtype: torch.dtype = torch.float16,
         Q_B: int = B,
         Q_H: int = H,
@@ -273,6 +273,7 @@ class TestFlexAttention(InductorTestCase):
         KV_H: int = H,
         KV_S: int = S,
         KV_D: int = D,
+        block_mask: Optional[BlockMask] = None,
     ):
         q = torch.randn(
             (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
@@ -285,7 +286,6 @@ class TestFlexAttention(InductorTestCase):
         )
         q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
         q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
-        block_mask = None
         sdpa_partial = create_attention(
             score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
         )
@@ -1437,7 +1437,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         out.sum().backward()
 
     @supported_platform
-    def test_fully_masked_out_rows(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows_0_check(self, compile: bool):
         # Ensure fully masked out rows won't cause NaNs.
         query = torch.randn(
             (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
@@ -1448,7 +1449,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         value = torch.randn(
             (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
         )
-        do = torch.randn((B, H, S, D), dtype=torch.float32, device="cuda")
 
         M = S // 2
 
@@ -1456,15 +1456,33 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             return q < M
 
         block_mask = create_block_mask(mask_mod, 1, 1, S, S)
-        out = torch.compile(flex_attention, dynamic=False)(
-            query, key, value, block_mask=block_mask
-        )
-        # TODO: Switch to self.run_test_with_call after updating eager SDPA masked_softmax semantics
-        self.assertEqual(out[:, :, M:, :].sum(), 0)
-
-        out.backward(do)
+
+        flex = (
+            torch.compile(flex_attention, dynamic=False) if compile else flex_attention
+        )
+        out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True)
+        self.assertEqual(out[:, :, M:, :].sum(), 0)
+        self.assertTrue((lse[:, :, M:] == 0.0).all())
+
+        loss = out.sum() + lse.sum()
+        loss.backward()
         self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
 
+    @supported_platform
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows(self, compile: bool):
+        M = S // 2
+
+        def mask_mod(b, h, q, kv):
+            return q < M
+
+        block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+        def noop_mod(score, b, h, q_idx, kv_idx):
+            return score
+
+        self.run_test(noop_mod, torch.float32, B, H, S, D, B, H, S, D, block_mask)
+
     @supported_platform
     def test_comparison_vs_sdpa(self):
         def causal(score, b, h, q_idx, kv_idx):
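A minimal standalone sketch of the behavior these tests lock in: when the block mask removes every key for a query row, flex_attention returns zeros for that row, reports a logsumexp of 0.0 rather than -inf or NaN, and lets no gradient flow into the masked rows. Sizes below are illustrative; this assumes a CUDA build with FlexAttention available.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 2, 4, 256, 64  # illustrative sizes
q = torch.randn(B, H, S, D, device="cuda", requires_grad=True)
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

M = S // 2

def mask_mod(b, h, q_idx, kv_idx):
    # Only the first M query rows may attend to anything;
    # rows with q_idx >= M are fully masked out.
    return q_idx < M

block_mask = create_block_mask(mask_mod, 1, 1, S, S)
out, lse = flex_attention(q, k, v, block_mask=block_mask, return_lse=True)

assert out[:, :, M:, :].abs().sum() == 0      # masked rows produce zeros
assert (lse[:, :, M:] == 0.0).all()           # ... and lse == 0.0, not -inf
(out.sum() + lse.sum()).backward()
assert q.grad[:, :, M:, :].abs().sum() == 0   # no gradient flows into masked rows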
@@ -284,15 +284,20 @@ class TestFlexDecoding(InductorTestCase):
             score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
         )
         compiled_sdpa = torch.compile(sdpa_partial)
-        golden_out = sdpa_partial(q_gold, k_gold, v_gold)
-        ref_out = sdpa_partial(q_ref, k_ref, v_ref)
-        compiled_out = compiled_sdpa(q, k, v)
+        golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
+        ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
+        compiled_out, compiled_lse = compiled_sdpa(q, k, v, return_lse=True)
 
         self._check_out(
             golden_out,
             ref_out,
             compiled_out,
         )
+        self._check_out(
+            gold_lse,
+            ref_lse,
+            compiled_lse,
+        )
 
     def run_test_with_call(
         self,
@@ -762,6 +767,38 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
 
         self.run_test(bias_mod)
 
+    @supported_platform
+    def test_fully_masked_out_rows_0_check_gqa(self):
+        # Ensure fully masked out rows won't cause NaNs.
+        query = torch.randn(
+            (B, Hq, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        key = torch.randn(
+            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        value = torch.randn(
+            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+
+        M = S // 2
+
+        def mask_mod(b, h, q, kv):
+            return q < M
+
+        block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+        flex = torch.compile(flex_attention, dynamic=False)
+
+        out, lse = flex(
+            query, key, value, block_mask=block_mask, enable_gqa=True, return_lse=True
+        )
+        self.assertEqual(out[:, :, M:, :].sum(), 0)
+        self.assertTrue((lse[:, :, M:] == 0.0).all())
+
+        loss = out.sum() + lse.sum()
+        loss.backward()
+        self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
+
     @supported_platform
     def test_windowed_no_mask_vs_sdpa(self):
         score_mod = _generate_windowed(1000)
@@ -204,11 +204,12 @@ def math_attention(
         mask_mod_other_buffers,
     )
 
-    # TODO Unconditionally return logsumexp for backwards
-    # if any(t.requires_grad for t in (query, key, value)):
+    # Set fully masked rows' sumexp to 0.0
     logsumexp = post_mod_scores.logsumexp(dim=-1)
+    masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)
+    logsumexp = torch.where(masked_rows, 0.0, logsumexp)
 
-    post_mod_scores = post_mod_scores.softmax(dim=-1)
+    post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)
 
     return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)
 
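A small sketch of the eager semantics this hunk gives the math_attention reference: a row whose every score is -inf gets its logsumexp pinned to 0.0, and its softmax is defined as all zeros rather than NaN. The snippet below uses a plain masked softmax in place of the private torch._safe_softmax helper to show the same result:

import math
import torch

scores = torch.tensor([
    [0.5, 1.0, -float("inf")],                      # partially masked row
    [-float("inf"), -float("inf"), -float("inf")],  # fully masked row
])

masked_rows = torch.all(scores == -float("inf"), dim=-1)

# logsumexp of a fully masked row would be -inf; pin it to 0.0 instead.
lse = torch.where(masked_rows, 0.0, scores.logsumexp(dim=-1))

# "Safe" softmax: a fully masked row becomes all zeros instead of NaN.
probs = scores.softmax(dim=-1)
probs = torch.where(masked_rows.unsqueeze(-1), 0.0, probs)

print(lse / math.log(2))  # FlexAttention reports lse in base 2
print(probs)              # second row is [0., 0., 0.], no NaNs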
@@ -302,8 +302,13 @@ compute_flex_attention = r"""
     )
 
 
-    # Store output and logsumexp
-    l_i = tl.where(l_i == 0, 1, l_i)
+    # [Note] Handle fully masked out rows:
+    # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
+    # We set Li to 1.0 which will result in lse/out = 0.0 after the log(li) + mi(0.0) step
+    l_i = tl.where(l_i == 0.0, 1, l_i)
+    masked_out_rows = (m_i == float("-inf"))
+    m_i = tl.where(masked_out_rows, 0, m_i)
+
     acc = acc / l_i[:, None]
     idx_z = tl.program_id(1) // HQ
     idx_hq = tl.program_id(1) % HQ
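Why the accumulator patch above makes both the output and the logsumexp come out as 0.0 for a fully masked row: after the scan, the running sum l_i is 0.0 and the running max m_i is -inf, and the epilogue rewrites them to 1.0 and 0.0 before dividing and taking the log. A plain-Python walk-through of that arithmetic (scalars stand in for the Triton block values):

import math

# Accumulator state after scanning a fully masked row: every score was -inf,
# so no exponential term was ever added.
m_i = -math.inf   # running max
l_i = 0.0         # running sum of 2**(score - m_i)
acc = 0.0         # running weighted sum of V (one scalar stands in for the row)

# Epilogue patch from the kernel:
l_i = 1.0 if l_i == 0.0 else l_i          # l_i = tl.where(l_i == 0.0, 1, l_i)
m_i = 0.0 if m_i == -math.inf else m_i    # m_i = tl.where(masked_out_rows, 0, m_i)

out = acc / l_i              # 0.0 / 1.0 -> 0.0 instead of 0/0 (NaN)
lse = m_i + math.log2(l_i)   # 0.0 + log2(1.0) -> 0.0 instead of -inf
print(out, lse)              # 0.0 0.0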
@@ -524,11 +524,17 @@ def create_flex_decoding_kernel(*args, **kwargs):
     # Reduction
 
     g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
+    # See [Note] Handle fully masked out rows:
+    # g_M Is the global max among split kv blocks.
+    masked_rows = lowerings[aten.eq](g_M, -float("inf"))
+    g_M = lowerings[aten.where](masked_rows, 0.0, g_M)
     adj_M = lowerings[aten.sub](buf_M, g_M)
     alpha = lowerings[aten.exp2](adj_M)
 
     buf_L = lowerings[aten.mul](buf_L, alpha)
     g_L = lowerings[aten.sum](buf_L, axis=1)
+    masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1)
+    g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L)
     logsumexp = lowerings[aten.log2](g_L)
     logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))
 
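The decoding kernel reduces per-split partial maxima and sums before forming the logsumexp; the fixups above keep that reduction at 0.0 for rows where every split only saw -inf. A small numeric sketch in plain PyTorch (standing in for the inductor lowerings) of the same reduction:

import torch

# Per-split running max (base 2) and sum for one query row, across 3 kv splits.
# A fully masked row has max -inf and sum 0.0 in every split.
buf_M = torch.tensor([[-float("inf"), -float("inf"), -float("inf")]])
buf_L = torch.tensor([[0.0, 0.0, 0.0]])

g_M = buf_M.max(dim=1, keepdim=True).values          # global max: -inf
masked_rows = g_M == -float("inf")
g_M = torch.where(masked_rows, 0.0, g_M)              # pin max to 0.0

alpha = torch.exp2(buf_M - g_M)                        # rescale factors (all 0 here)
g_L = (buf_L * alpha).sum(dim=1)                       # combined sum: 0.0
g_L = torch.where(masked_rows.squeeze(1), 1.0, g_L)    # pin sum to 1.0

lse = torch.log2(g_L) + g_M.squeeze(1)                 # log2(1) + 0 = 0.0
print(lse)  # tensor([0.])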