From 629bd6f7184da58e21ba1580efbc8ff0c24dad14 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Thu, 22 Aug 2024 10:28:52 -0700
Subject: [PATCH] Update FlexAttention with masking semantic (#133373)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133373
Approved by: https://github.com/yanboliang
---
 test/inductor/test_flex_attention.py      | 42 +++++++++++++++-------
 test/inductor/test_flex_decoding.py       | 43 +++++++++++++++++++++--
 torch/_higher_order_ops/flex_attention.py |  7 ++--
 torch/_inductor/kernel/flex_attention.py  |  9 +++--
 torch/_inductor/kernel/flex_decoding.py   |  6 ++++
 5 files changed, 87 insertions(+), 20 deletions(-)

diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 428aca1a23a..82e404f2723 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -17,6 +17,7 @@ from torch.nn.attention.flex_attention import (
     _create_empty_block_mask,
     _DEFAULT_SPARSE_BLOCK_SIZE,
     _identity,
+    _score_mod_signature,
     and_masks,
     BlockMask,
     create_block_mask,
@@ -212,8 +213,7 @@ class TestFlexAttention(InductorTestCase):
     ):
         compiled_error = (golden_out - compiled_out).abs().mean()
         ref_error = (golden_out - ref_out).abs().mean()
-        # TODO: Make this check stricter after updating eager SDPA masked_softmax semantics
-        if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
+        if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
             self.assertTrue(False, "Output/Grad with NaN")
         if compiled_error > ref_error * fudge_factor:
             name = tensor_name if tensor_name is not None else ""
@@ -263,7 +263,7 @@ class TestFlexAttention(InductorTestCase):
 
     def run_test(
         self,
-        score_mod: Callable,
+        score_mod: _score_mod_signature,
         dtype: torch.dtype = torch.float16,
         Q_B: int = B,
         Q_H: int = H,
@@ -273,6 +273,7 @@ class TestFlexAttention(InductorTestCase):
         KV_H: int = H,
         KV_S: int = S,
         KV_D: int = D,
+        block_mask: Optional[BlockMask] = None,
     ):
         q = torch.randn(
             (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
@@ -285,7 +286,6 @@ class TestFlexAttention(InductorTestCase):
         )
         q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
         q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
-        block_mask = None
         sdpa_partial = create_attention(
             score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
         )
@@ -1437,7 +1437,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         out.sum().backward()
 
     @supported_platform
-    def test_fully_masked_out_rows(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows_0_check(self, compile: bool):
         # Ensure fully masked out rows won't cause NaNs.
         query = torch.randn(
             (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
@@ -1448,7 +1449,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         value = torch.randn(
             (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
         )
-        do = torch.randn((B, H, S, D), dtype=torch.float32, device="cuda")
 
         M = S // 2
 
@@ -1456,15 +1456,33 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             return q < M
 
         block_mask = create_block_mask(mask_mod, 1, 1, S, S)
-        out = torch.compile(flex_attention, dynamic=False)(
-            query, key, value, block_mask=block_mask
-        )
-        # TODO: Switch to self.run_test_with_call after updating eager SDPA masked_softmax semantics
-        self.assertEqual(out[:, :, M:, :].sum(), 0)
-        out.backward(do)
+        flex = (
+            torch.compile(flex_attention, dynamic=False) if compile else flex_attention
+        )
+        out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True)
+        self.assertEqual(out[:, :, M:, :].sum(), 0)
+        self.assertTrue((lse[:, :, M:] == 0.0).all())
+
+        loss = out.sum() + lse.sum()
+        loss.backward()
 
         self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
 
+    @supported_platform
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows(self, compile: bool):
+        M = S // 2
+
+        def mask_mod(b, h, q, kv):
+            return q < M
+
+        block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+        def noop_mod(score, b, h, q_idx, kv_idx):
+            return score
+
+        self.run_test(noop_mod, torch.float32, B, H, S, D, B, H, S, D, block_mask)
+
     @supported_platform
     def test_comparison_vs_sdpa(self):
         def causal(score, b, h, q_idx, kv_idx):
diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
index 78324421cc8..9a51a9504a9 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -284,15 +284,20 @@ class TestFlexDecoding(InductorTestCase):
             score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
         )
         compiled_sdpa = torch.compile(sdpa_partial)
-        golden_out = sdpa_partial(q_gold, k_gold, v_gold)
-        ref_out = sdpa_partial(q_ref, k_ref, v_ref)
-        compiled_out = compiled_sdpa(q, k, v)
+        golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
+        ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
+        compiled_out, compiled_lse = compiled_sdpa(q, k, v, return_lse=True)
 
         self._check_out(
             golden_out,
             ref_out,
             compiled_out,
         )
+        self._check_out(
+            gold_lse,
+            ref_lse,
+            compiled_lse,
+        )
 
     def run_test_with_call(
         self,
@@ -762,6 +767,38 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
 
         self.run_test(bias_mod)
 
+    @supported_platform
+    def test_fully_masked_out_rows_0_check_gqa(self):
+        # Ensure fully masked out rows won't cause NaNs.
+        query = torch.randn(
+            (B, Hq, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        key = torch.randn(
+            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        value = torch.randn(
+            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+
+        M = S // 2
+
+        def mask_mod(b, h, q, kv):
+            return q < M
+
+        block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+        flex = torch.compile(flex_attention, dynamic=False)
+
+        out, lse = flex(
+            query, key, value, block_mask=block_mask, enable_gqa=True, return_lse=True
+        )
+        self.assertEqual(out[:, :, M:, :].sum(), 0)
+        self.assertTrue((lse[:, :, M:] == 0.0).all())
+
+        loss = out.sum() + lse.sum()
+        loss.backward()
+        self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
+
     @supported_platform
     def test_windowed_no_mask_vs_sdpa(self):
         score_mod = _generate_windowed(1000)
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 4d76e132514..b651d7d9b93 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -204,11 +204,12 @@ def math_attention(
         mask_mod_other_buffers,
     )
 
-    # TODO Unconditionally return logsumexp for backwards
-    # if any(t.requires_grad for t in (query, key, value)):
+    # Set fully masked rows' sumexp to 0.0
     logsumexp = post_mod_scores.logsumexp(dim=-1)
+    masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)
+    logsumexp = torch.where(masked_rows, 0.0, logsumexp)
 
-    post_mod_scores = post_mod_scores.softmax(dim=-1)
+    post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)
 
     return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)
 
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 8a1286081b4..05e9f1c768a 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -302,8 +302,13 @@ compute_flex_attention = r"""
         )
 
 
-    # Store output and logsumexp
-    l_i = tl.where(l_i == 0, 1, l_i)
+    # [Note] Handle fully masked out rows:
+    # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
+    # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
+    l_i = tl.where(l_i == 0.0, 1, l_i)
+    masked_out_rows = (m_i == float("-inf"))
+    m_i = tl.where(masked_out_rows, 0, m_i)
+
     acc = acc / l_i[:, None]
     idx_z = tl.program_id(1) // HQ
     idx_hq = tl.program_id(1) % HQ
diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
index 249b9b8f606..96e1dbc19cf 100644
--- a/torch/_inductor/kernel/flex_decoding.py
+++ b/torch/_inductor/kernel/flex_decoding.py
@@ -524,11 +524,17 @@ def create_flex_decoding_kernel(*args, **kwargs):
 
     # Reduction
     g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
+    # See [Note] Handle fully masked out rows:
+    # g_M Is the global max among split kv blocks.
+    masked_rows = lowerings[aten.eq](g_M, -float("inf"))
+    g_M = lowerings[aten.where](masked_rows, 0.0, g_M)
     adj_M = lowerings[aten.sub](buf_M, g_M)
     alpha = lowerings[aten.exp2](adj_M)
 
     buf_L = lowerings[aten.mul](buf_L, alpha)
     g_L = lowerings[aten.sum](buf_L, axis=1)
+    masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1)
+    g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L)
     logsumexp = lowerings[aten.log2](g_L)
     logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))
 
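
Note: the sketch below is not part of the patch above; it is a minimal eager-mode illustration of the masking semantics the patch establishes, namely that a query row whose key/value positions are all masked out produces an all-zero output row and a logsumexp of 0.0 instead of NaN. The helper name reference_masked_attention, the tensor shapes, and the hand-rolled "safe softmax" (standing in for the private torch._safe_softmax used in math_attention) are illustrative assumptions, and the base-2 logsumexp scaling used by the real kernels is omitted.

import torch


def reference_masked_attention(query, key, value, mask):
    # mask: bool [..., S_q, S_kv]; True = attend, False = masked out.
    scores = (query @ key.transpose(-2, -1)) / query.shape[-1] ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))

    # Rows where every kv position is masked out.
    fully_masked = (~mask).all(dim=-1)

    # logsumexp of an all -inf row is -inf; pin it to 0.0, mirroring the patch.
    lse = scores.logsumexp(dim=-1)
    lse = torch.where(fully_masked, torch.zeros_like(lse), lse)

    # "Safe" softmax: a plain softmax yields NaN (0/0) on all -inf rows, so
    # force those rows to all zeros, which also zeroes the output row.
    probs = scores.softmax(dim=-1)
    probs = torch.where(fully_masked.unsqueeze(-1), torch.zeros_like(probs), probs)
    return probs @ value, lse


B, H, S, D, M = 2, 4, 8, 16, 4
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
# Queries at positions >= M attend to nothing (mirrors mask_mod: q < M).
mask = (torch.arange(S).view(1, 1, S, 1) < M).expand(B, H, S, S)
out, lse = reference_masked_attention(q, k, v, mask)
assert out[:, :, M:, :].abs().sum() == 0
assert (lse[:, :, M:] == 0.0).all()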