Revert "Ensure that BlockMask length must always exactly match the sequence length in flex_attention (#141625)"
This reverts commit 795f28ac55.
Reverted https://github.com/pytorch/pytorch/pull/141625 on behalf of https://github.com/albanD due to Broken main ([comment](https://github.com/pytorch/pytorch/pull/141625#issuecomment-2511639687))
This commit is contained in: parent ec96597e47, commit a34a56f69f
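The behavior this revert restores is the one exercised by the tests below: a BlockMask built once for a maximum sequence length can be reused for shorter inputs, rather than having to match the sequence length exactly. A minimal sketch of that usage, not part of the diff; it assumes a CUDA build of PyTorch with torch.nn.attention.flex_attention available, and the causal mask and sizes are illustrative only:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # standard causal mask_mod: a query may attend to positions at or before it
    return q_idx >= kv_idx

MAX_S = 2048
# Build the block mask once for the largest sequence length we expect to see.
block_mask = create_block_mask(causal, 1, 1, MAX_S, MAX_S, device="cuda")

# A shorter batch (S = 1024) reuses the same block_mask; with this revert in place,
# flex_attention crops/adjusts the mask to the actual length instead of requiring
# an exact match between mask length and sequence length.
q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)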
@ -668,64 +668,62 @@ class TestFlexAttention(InductorTestCase):
|
|||
D: int = D,
|
||||
):
|
||||
score_mod, mask_mod = score_mask_mod
|
||||
|
||||
# First batch with original dimensions (B, H, S, D)
|
||||
block_mask1 = create_block_mask(mask_mod, 1, 1, S, S)
|
||||
sdpa_partial1 = create_attention(score_mod, block_mask=block_mask1)
|
||||
|
||||
# If the seqlen becomes smaller than the seqlen of the previous batch,
|
||||
# we can still reuse the block_mask created from a larger seqlen.
|
||||
MAX_S = S
|
||||
block_mask = create_block_mask(mask_mod, 1, 1, MAX_S, MAX_S)
|
||||
sdpa_partial = create_attention(score_mod, block_mask=block_mask)
|
||||
# The first eager batch, shape (B, H, S, D)
|
||||
q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1)
|
||||
q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64)
|
||||
ref_out1 = sdpa_partial1(q1_ref, k1_ref, v1_ref)
|
||||
golden_out1 = sdpa_partial1(q1_gold, k1_gold, v1_gold)
|
||||
ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref)
|
||||
golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold)
|
||||
|
||||
backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
|
||||
golden_out1.backward(backward_grad1.to(torch.float64))
|
||||
ref_out1.backward(backward_grad1)
|
||||
|
||||
# Second batch with modified dimensions (B * 2, H, S / 2, D)
|
||||
# The second eager batch, shape (B * 2, H, S / 2, D)
|
||||
B = int(B * 2)
|
||||
S = int(S / 2)
|
||||
block_mask2 = create_block_mask(mask_mod, 1, 1, S, S)
|
||||
sdpa_partial2 = create_attention(score_mod, block_mask=block_mask2)
|
||||
|
||||
q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2)
|
||||
q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64)
|
||||
ref_out2 = sdpa_partial2(q2_ref, k2_ref, v2_ref)
|
||||
golden_out2 = sdpa_partial2(q2_gold, k2_gold, v2_gold)
|
||||
ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref)
|
||||
golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold)
|
||||
|
||||
backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
|
||||
golden_out2.backward(backward_grad2.to(torch.float64))
|
||||
ref_out2.backward(backward_grad2)
|
||||
|
||||
# Third batch with modified dimensions (B * 2, H, S / 4, D)
|
||||
# The third eager batch, shape (B * 2, H, S / 4, D)
|
||||
S = int(S / 2)
|
||||
block_mask3 = create_block_mask(mask_mod, 1, 1, S, S)
|
||||
sdpa_partial3 = create_attention(score_mod, block_mask=block_mask3)
|
||||
|
||||
q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
q3_ref, k3_ref, v3_ref = query_key_value_clones(q3, k3, v3)
|
||||
q3_gold, k3_gold, v3_gold = query_key_value_clones(q3, k3, v3, torch.float64)
|
||||
ref_out3 = sdpa_partial3(q3_ref, k3_ref, v3_ref)
|
||||
golden_out3 = sdpa_partial3(q3_gold, k3_gold, v3_gold)
|
||||
ref_out3 = sdpa_partial(q3_ref, k3_ref, v3_ref)
|
||||
golden_out3 = sdpa_partial(q3_gold, k3_gold, v3_gold)
|
||||
|
||||
backward_grad3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
|
||||
golden_out3.backward(backward_grad3.to(torch.float64))
|
||||
ref_out3.backward(backward_grad3)
|
||||
|
||||
# Clear dynamo counters
|
||||
# Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
|
||||
# We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
|
||||
torch._dynamo.reset()
|
||||
|
||||
# First compilation with original dimensions
|
||||
compiled_sdpa1 = torch.compile(sdpa_partial1, dynamic=True)
|
||||
compiled_out1 = compiled_sdpa1(q1, k1, v1)
|
||||
# Compiling with dynamic shape in the first batch.
|
||||
compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
|
||||
compiled_out1 = compiled_sdpa(q1, k1, v1)
|
||||
compiled_out1.backward(backward_grad1)
|
||||
|
||||
self._check_out_and_grad(
|
||||
|
|
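The comments in the hunk above explain why the test resets dynamo state first: eager flex_attention itself runs through dynamo tracing, so the frame counters must be cleared before counting compilations of the test's own function. A hedged, self-contained sketch of that counting pattern (the causal mask, shapes, and CUDA device are illustrative assumptions, not taken from the test):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, 1, 1, 2048, 2048, device="cuda")

def sdpa(q, k, v):
    return flex_attention(q, k, v, block_mask=block_mask)

torch._dynamo.reset()  # clear counters; eager flex_attention also traces through dynamo
compiled_sdpa = torch.compile(sdpa, dynamic=True)

q, k, v = (torch.randn(2, 8, 2048, 64, device="cuda") for _ in range(3))
compiled_sdpa(q, k, v)
# counters["frames"]["ok"] counts successfully compiled frames, so comparing it across
# calls reveals whether a later input shape triggered a re-compilation.
print(torch._dynamo.utils.counters["frames"]["ok"])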
@ -744,11 +742,10 @@ class TestFlexAttention(InductorTestCase):
|
|||
)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
|
||||
|
||||
# Second compilation with new dimensions
|
||||
compiled_sdpa2 = torch.compile(sdpa_partial2, dynamic=True)
|
||||
compiled_out2 = compiled_sdpa2(q2, k2, v2)
|
||||
# Since the current q_seqlen (MAX_S/2) is smaller than the seqlen from block_mask (MAX_S),
# we recompile to include the BlockMask._adjust part.
|
||||
compiled_out2 = compiled_sdpa(q2, k2, v2)
|
||||
compiled_out2.backward(backward_grad2)
|
||||
|
||||
self._check_out_and_grad(
|
||||
golden_out2,
|
||||
ref_out2,
|
||||
|
|
@ -763,13 +760,13 @@ class TestFlexAttention(InductorTestCase):
|
|||
v2_ref,
|
||||
v2,
|
||||
)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
|
||||
|
||||
# Third compilation with new dimensions
|
||||
compiled_sdpa3 = torch.compile(sdpa_partial3, dynamic=True)
|
||||
compiled_out3 = compiled_sdpa3(q3, k3, v3)
|
||||
# No re-compilation, use the compiled dynamic shape version.
|
||||
# The current q_seqlen (MAX_S/4) is still smaller than the seqlen from block_mask (MAX_S),
# so we don't recompile: we can reuse the compiled graph, which already includes the BlockMask._adjust part.
|
||||
compiled_out3 = compiled_sdpa(q3, k3, v3)
|
||||
compiled_out3.backward(backward_grad3)
|
||||
|
||||
self._check_out_and_grad(
|
||||
golden_out3,
|
||||
ref_out3,
|
||||
|
|
@ -784,7 +781,18 @@ class TestFlexAttention(InductorTestCase):
|
|||
v3_ref,
|
||||
v3,
|
||||
)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
|
||||
|
||||
# The fourth iteration, shape (B * 2, H, S * 2, D)
# Since the seqlen is larger than the seqlen in block_mask, an error is thrown.
|
||||
S = int(S * 8)
|
||||
q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.BackendCompilerFailed, "Q seqlen must be smaller than"
|
||||
):
|
||||
compiled_sdpa(q3, k3, v3)
|
||||
|
||||
def run_automatic_dynamic_test(
|
||||
self,
|
||||
|
|
@ -796,42 +804,38 @@ class TestFlexAttention(InductorTestCase):
|
|||
D: int = D,
|
||||
):
|
||||
MAX_S = S
|
||||
block_mask1 = create_block_mask(noop_mask, 1, 1, S, S)
|
||||
sdpa_partial1 = create_attention(score_mod, block_mask=block_mask1)
|
||||
block_mask = create_block_mask(noop_mask, 1, 1, MAX_S, MAX_S)
|
||||
sdpa_partial = create_attention(score_mod, block_mask=block_mask)
|
||||
# The first eager batch, shape (B, H, S, D)
|
||||
q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
golden_out1 = sdpa_partial1(
|
||||
golden_out1 = sdpa_partial(
|
||||
q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
|
||||
)
|
||||
ref_out1 = sdpa_partial1(q1, k1, v1)
|
||||
ref_out1 = sdpa_partial(q1, k1, v1)
|
||||
|
||||
# The second eager batch, shape (B * 2, H, S / 2, D)
|
||||
B = int(B * 2)
|
||||
S = int(S / 2)
|
||||
block_mask2 = create_block_mask(noop_mask, 1, 1, S, S)
|
||||
sdpa_partial2 = create_attention(score_mod, block_mask=block_mask2)
|
||||
q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
golden_out2 = sdpa_partial2(
|
||||
golden_out2 = sdpa_partial(
|
||||
q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
|
||||
)
|
||||
ref_out2 = sdpa_partial2(q2, k2, v2)
|
||||
ref_out2 = sdpa_partial(q2, k2, v2)
|
||||
|
||||
# The third eager batch, shape (B * 4, H, S / 4, D)
|
||||
B = int(B * 2)
|
||||
S = int(S / 2)
|
||||
block_mask3 = create_block_mask(noop_mask, 1, 1, S, S)
|
||||
sdpa_partial3 = create_attention(score_mod, block_mask=block_mask3)
|
||||
q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
|
||||
golden_out3 = sdpa_partial3(
|
||||
golden_out3 = sdpa_partial(
|
||||
q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
|
||||
)
|
||||
ref_out3 = sdpa_partial3(q3, k3, v3)
|
||||
ref_out3 = sdpa_partial(q3, k3, v3)
|
||||
|
||||
# Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
|
||||
# We check dynamo counters["frames"]["ok"] to ensure:
|
||||
|
|
@ -848,17 +852,18 @@ class TestFlexAttention(InductorTestCase):
|
|||
fudge_factor = 1.1
|
||||
|
||||
# The first batch.
|
||||
compiled_out1 = torch.compile(sdpa_partial1)(q1, k1, v1)
|
||||
compiled_sdpa = torch.compile(sdpa_partial)
|
||||
compiled_out1 = compiled_sdpa(q1, k1, v1)
|
||||
self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
|
||||
|
||||
# The second batch (automatic dynamic).
|
||||
compiled_out2 = torch.compile(sdpa_partial2)(q2, k2, v2)
|
||||
compiled_out2 = compiled_sdpa(q2, k2, v2)
|
||||
self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
|
||||
|
||||
# The third batch (no re-compilation).
|
||||
compiled_out3 = torch.compile(sdpa_partial3)(q3, k3, v3)
|
||||
compiled_out3 = compiled_sdpa(q3, k3, v3)
|
||||
self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
|
||||
|
||||
|
|
@ -907,6 +912,11 @@ class TestFlexAttention(InductorTestCase):
|
|||
def test_builtin_score_mods_dynamic(
|
||||
self, dtype: torch.dtype, score_mask_mod: Tuple[Callable, Callable]
|
||||
):
|
||||
if score_mask_mod[0].__name__ == "_alibi_bias":
|
||||
# TODO
|
||||
self.skipTest(
|
||||
"Alibi bias broken with dynamic shapes since we don't support capturing dynamic shapes"
|
||||
)
|
||||
self.run_dynamic_test(score_mask_mod, dtype)
|
||||
|
||||
@supported_platform
|
||||
|
|
@ -2193,7 +2203,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|||
|
||||
# Use weird mask to test reusing block_mask does work well.
|
||||
@supported_platform
|
||||
def _test_block_mask_reuse_with_weird_mask(self):
|
||||
def test_block_mask_reuse_with_weird_mask(self):
|
||||
def mask(b, h, q, kv):
|
||||
return (kv < 256) | (kv >= 2048)
|
||||
|
||||
|
|
@ -3221,12 +3231,12 @@ def forward(self, child : torch.Tensor, child_1 : torch.Tensor, child_2 : torch.
|
|||
norm_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_value_: "f64[2, 2, 128, 4]", L_block_mask_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_q_num_blocks: "i32[1, 1, 1]", L_block_mask_q_indices: "i32[1, 1, 1, 1]", L_block_mask_full_q_num_blocks: "i32[1, 1, 1]", L_block_mask_full_q_indices: "i32[1, 1, 1, 1]"):
|
||||
def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_value_: "f64[2, 2, 128, 4]", L_block_mask_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_full_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_q_num_blocks: "i32[1, 1, 1]", L_block_mask_q_indices: "i32[1, 1, 1, 1]", L_block_mask_full_q_num_blocks: "i32[1, 1, 1]", L_block_mask_full_q_indices: "i32[1, 1, 1, 1]"):
|
||||
l_query_ = L_query_
|
||||
l_key_ = L_key_
|
||||
l_value_ = L_value_
|
||||
l_block_mask_kv_indices = L_block_mask_kv_indices
|
||||
l_block_mask_kv_num_blocks = L_block_mask_kv_num_blocks
|
||||
l_block_mask_kv_indices = L_block_mask_kv_indices
|
||||
l_block_mask_full_kv_num_blocks = L_block_mask_full_kv_num_blocks
|
||||
l_block_mask_full_kv_indices = L_block_mask_full_kv_indices
|
||||
l_block_mask_q_num_blocks = L_block_mask_q_num_blocks
|
||||
|
|
@ -3236,7 +3246,7 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
score_mod_0 = self.score_mod_0
|
||||
mask_fn_0 = self.mask_fn_0
|
||||
flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
|
||||
flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
|
||||
out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
|
||||
return (out,)
|
||||
|
||||
|
|
@ -3277,7 +3287,7 @@ class GraphModule(torch.nn.Module):
|
|||
fw_graph0 = self.fw_graph0
|
||||
joint_graph0 = self.joint_graph0
|
||||
mask_graph0 = self.mask_graph0
|
||||
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
|
||||
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
|
||||
getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
|
||||
getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
|
||||
getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
|
||||
|
|
@ -3633,7 +3643,6 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
|
|||
full_q_indices=None,
|
||||
BLOCK_SIZE=(64, 64),
|
||||
mask_mod=noop_mask,
|
||||
seq_lengths=(1, 1),
|
||||
)
|
||||
|
||||
@supported_platform
|
||||
|
|
@ -3653,7 +3662,6 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
|
|||
full_q_indices=None, # Mismatched, should raise error
|
||||
BLOCK_SIZE=(64, 64),
|
||||
mask_mod=noop_mask,
|
||||
seq_lengths=(1, 1),
|
||||
)
|
||||
|
||||
@supported_platform
|
||||
|
|
@ -3779,35 +3787,6 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
|
|||
block_mask = create_block_mask(doc_mask_mod, None, None, 1024 + i, 1024 + i)
|
||||
torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
|
||||
|
||||
@common_utils.parametrize("compile", [False, True])
|
||||
@supported_platform
|
||||
def test_block_mask_vs_sequence_lengths(self, compile):
|
||||
if compile:
|
||||
flex_attention_call = torch.compile(flex_attention)
|
||||
else:
|
||||
flex_attention_call = flex_attention
|
||||
|
||||
def mask_mod(b, h, q_idx, kv_idx):
|
||||
return q_idx >= kv_idx
|
||||
|
||||
def create_inputs(S):
|
||||
q, k, v = (
|
||||
torch.randn(
|
||||
1, 8, S, 64, dtype=torch.float16, requires_grad=True, device="cuda"
|
||||
)
|
||||
for _ in range(3)
|
||||
)
|
||||
return q, k, v
|
||||
|
||||
block_mask = create_block_mask(mask_mod, None, None, 1024, 1024)
|
||||
flex_attention_call(*create_inputs(1024), block_mask=block_mask)
|
||||
with self.assertRaisesRegex(ValueError, "block_mask was created for"):
|
||||
flex_attention_call(*create_inputs(2048), block_mask=block_mask)
|
||||
|
||||
block_mask = create_block_mask(mask_mod, None, None, 1023, 1023)
|
||||
with self.assertRaisesRegex(ValueError, "block_mask was created for"):
|
||||
flex_attention_call(*create_inputs(1024), block_mask=block_mask)
|
||||
|
||||
|
||||
class TestPagedAttention(InductorTestCase):
|
||||
def _check_equal(
|
||||
|
|
|
|||
|
|
@ -579,9 +579,9 @@ class TestFlexDecoding(InductorTestCase):
|
|||
ref_out = golden_call(q_ref, k_ref, v_ref)
|
||||
|
||||
if mask_mod is not None:
|
||||
block_mask = create_block_mask(mask_mod, Q_B, 1, Q_S, KV_S)
|
||||
block_mask = create_block_mask(mask_mod, Q_B, 1, 1, S)
|
||||
else:
|
||||
block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S)
|
||||
block_mask = create_block_mask(noop_mask, Q_B, 1, 1, S)
|
||||
|
||||
compiled_out, _ = self.run_paged_attention(
|
||||
score_mod, q, k, v, dtype, block_mask
|
||||
|
|
@ -682,7 +682,7 @@ class TestFlexDecoding(InductorTestCase):
|
|||
score_mod: Callable,
|
||||
BLOCK_SIZE: Union[int, Tuple[int, int]],
|
||||
):
|
||||
block_mask = create_block_mask(noop_mask, B, 1, 1, S, BLOCK_SIZE=BLOCK_SIZE)
|
||||
block_mask = create_block_mask(noop_mask, B, 1, S, S, BLOCK_SIZE=BLOCK_SIZE)
|
||||
self.run_test(score_mod, dtype, block_mask=block_mask)
|
||||
|
||||
def input_strides_1(B, H, S, D):
|
||||
|
|
@ -1098,7 +1098,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|||
def scoremod_2(qk, b, h, q, kv):
|
||||
return torch.where(q >= kv, qk, -float("inf"))
|
||||
|
||||
block_mask = create_block_mask(noop_mask, 1, 1, 4, 1024)
|
||||
block_mask = create_block_mask(noop_mask, 1, 1, 1, S)
|
||||
|
||||
def f(q, k1, k2, v1, v2):
|
||||
q2 = flex_attention(q, k1, v1, score_mod=scoremod_1, block_mask=block_mask)
|
||||
|
|
@ -1167,7 +1167,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|||
def scoremod_2(qk, b, h, q, kv):
|
||||
return torch.where(q >= kv, qk, -float("inf"))
|
||||
|
||||
block_mask = create_block_mask(noop_mask, 1, 1, 4, 1024)
|
||||
block_mask = create_block_mask(noop_mask, 1, 1, 1, S)
|
||||
|
||||
attention1 = functools.partial(
|
||||
flex_attention, score_mod=scoremod_1, block_mask=block_mask
|
||||
|
|
@ -1567,8 +1567,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|||
mask_mod=mask_mod,
|
||||
B=2,
|
||||
H=None,
|
||||
Q_LEN=2,
|
||||
KV_LEN=2,
|
||||
Q_LEN=128,
|
||||
KV_LEN=256,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -614,7 +614,7 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
|
|||
value,
|
||||
out,
|
||||
logsumexp,
|
||||
*block_mask[:-1],
|
||||
*block_mask[:10],
|
||||
*score_mod_other_buffers,
|
||||
*mask_mod_other_buffers,
|
||||
),
|
||||
|
|
@ -630,8 +630,6 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
|
|||
value,
|
||||
out,
|
||||
logsumexp,
|
||||
query_lengths,
|
||||
kv_lengths,
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks,
|
||||
|
|
@ -674,8 +672,6 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
|
|||
fw_graph,
|
||||
joint_graph,
|
||||
(
|
||||
query_lengths,
|
||||
kv_lengths,
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks,
|
||||
|
|
@ -712,8 +708,7 @@ def flex_attention_autograd(
|
|||
|
||||
with TransformGetItemToIndex():
|
||||
input_requires_grad = any(
|
||||
isinstance(t, torch.Tensor) and t.requires_grad
|
||||
for t in (query, key, value, *score_mod_other_buffers)
|
||||
t.requires_grad for t in (query, key, value, *score_mod_other_buffers)
|
||||
)
|
||||
if torch.is_grad_enabled() and input_requires_grad:
|
||||
example_vals = (
|
||||
|
|
@ -1135,9 +1130,7 @@ def flex_attention_backward_fake_tensor_mode(
|
|||
grad_value = torch.empty_like(value)
|
||||
grad_score_mod_captured = tuple(
|
||||
[
|
||||
torch.empty_like(buffer)
|
||||
if isinstance(buffer, torch.Tensor) and buffer.requires_grad
|
||||
else None
|
||||
torch.empty_like(buffer) if buffer.requires_grad else None
|
||||
for buffer in score_mod_other_buffers
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -810,8 +810,6 @@ def flex_attention(
|
|||
mask_mod_other_buffers,
|
||||
):
|
||||
(
|
||||
_, # q_length
|
||||
_, # kv_length
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks,
|
||||
|
|
@ -970,6 +968,12 @@ def flex_attention(
|
|||
# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
|
||||
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
|
||||
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
|
||||
), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE))
|
||||
), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
|
||||
|
||||
# Note, we don't need to pass in the captured buffers explicitly
|
||||
# because they're implicitly added by the score_mod function
|
||||
|
|
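The two seqlen assertions in the hunk above are a capacity check: the sequence lengths seen by the kernel must fit inside the block mask, whose capacity per dimension is the number of sparse blocks times the sparse block size. A small worked example with illustrative numbers (not taken from the diff):

# A mask built over 2048 x 2048 tokens with 128-sized sparse blocks has
# 2048 / 128 = 16 blocks along each dimension, i.e. kv_indices of shape [..., 16, 16].
num_q_blocks, num_kv_blocks = 16, 16
SPARSE_Q_BLOCK_SIZE = SPARSE_KV_BLOCK_SIZE = 128

def fits(seq_len_q, seq_len_kv):
    # mirrors the sympy.Le checks: each seqlen must not exceed blocks * block_size
    return (seq_len_q <= num_q_blocks * SPARSE_Q_BLOCK_SIZE
            and seq_len_kv <= num_kv_blocks * SPARSE_KV_BLOCK_SIZE)

assert fits(1024, 1024)       # shorter than the mask: allowed, the mask is reusable
assert not fits(4096, 4096)   # longer than the mask: rejected, pass a larger block_mask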
@ -1505,7 +1509,7 @@ def bwd_dq_block_mn(
|
|||
) | indent_except_first(2) }}
|
||||
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
|
||||
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
|
||||
# apply mask for partial masked block
|
||||
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
@ -1537,7 +1541,7 @@ def bwd_dq_block_mn(
|
|||
|
||||
if not IS_FULL_BLOCKS:
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
|
||||
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
|
||||
# (grads) apply mask for partially unmasked block
|
||||
ds = tl.where(mask_mod_output, ds, 0.0)
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
@ -1687,7 +1691,7 @@ def bwd_dkdv_block_mn(
|
|||
n="n",
|
||||
) | indent_except_first(2) }}
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
|
||||
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
|
||||
# (grads) apply mask for fully masked block
|
||||
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
@ -1745,7 +1749,7 @@ def bwd_dkdv_block_mn(
|
|||
dsT = grad_scores
|
||||
if not IS_FULL_BLOCKS:
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
|
||||
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
|
||||
# (grads) apply mask for partially unmasked block
|
||||
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
@ -1856,8 +1860,6 @@ def flex_attention_backward(*args, **kwargs):
|
|||
mask_mod_other_buffers,
|
||||
) = args
|
||||
(
|
||||
_, # q_length
|
||||
_, # kv_length
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks,
|
||||
|
|
@ -2034,9 +2036,6 @@ def flex_attention_backward(*args, **kwargs):
|
|||
or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
|
||||
):
|
||||
continue
|
||||
if num_warps == 8:
|
||||
# Working around https://github.com/pytorch/pytorch/issues/141603
|
||||
continue
|
||||
|
||||
# Performance tuning
|
||||
cur_kernel_options = original_kernel_options.copy()
|
||||
|
|
|
|||
|
|
@ -332,8 +332,6 @@ def create_flex_decoding_kernel(*args, **kwargs):
|
|||
mask_mod_other_buffers,
|
||||
) = args
|
||||
(
|
||||
_, # q_length
|
||||
_, # kv_length
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks, # full_kv_num_blocks,
|
||||
|
|
|
|||
|
|
@ -264,7 +264,6 @@ class PagedAttention:
|
|||
|
||||
new_mask_mod = self.get_mask_mod(block_mask.mask_mod)
|
||||
|
||||
seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
|
||||
return BlockMask.from_kv_blocks(
|
||||
new_kv_num_blocks,
|
||||
new_kv_indices,
|
||||
|
|
@ -272,7 +271,6 @@ class PagedAttention:
|
|||
new_full_kv_indices,
|
||||
block_mask.BLOCK_SIZE,
|
||||
new_mask_mod,
|
||||
seq_lengths=seq_lengths,
|
||||
)
|
||||
|
||||
def get_mask_mod(
|
||||
|
|
|
|||
|
|
@ -262,7 +262,6 @@ class BlockMask:
|
|||
the backwards pass. These are autogenerated from 2.
|
||||
"""
|
||||
|
||||
seq_lengths: Tuple[int, int]
|
||||
kv_num_blocks: Tensor
|
||||
kv_indices: Tensor
|
||||
full_kv_num_blocks: Optional[Tensor]
|
||||
|
|
@ -276,7 +275,6 @@ class BlockMask:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
seq_lengths: Tuple[int, int],
|
||||
kv_num_blocks: Tensor,
|
||||
kv_indices: Tensor,
|
||||
full_kv_num_blocks: Optional[Tensor],
|
||||
|
|
@ -301,7 +299,6 @@ class BlockMask:
|
|||
full_q_indices is None
|
||||
), "full_q_num_blocks and full_q_indices must be both provided or omitted"
|
||||
|
||||
self.seq_lengths = seq_lengths
|
||||
self.kv_num_blocks = kv_num_blocks
|
||||
self.kv_indices = kv_indices
|
||||
self.full_kv_num_blocks = full_kv_num_blocks
|
||||
|
|
@ -322,7 +319,6 @@ class BlockMask:
|
|||
full_kv_indices: Optional[Tensor] = None,
|
||||
BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
|
||||
mask_mod: Optional[_mask_mod_signature] = None,
|
||||
seq_lengths: Optional[Tuple[int, int]] = None,
|
||||
):
|
||||
"""
|
||||
Creates a BlockMask instance from key-value block information.
|
||||
|
|
@ -363,13 +359,8 @@ class BlockMask:
|
|||
BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE)
|
||||
|
||||
mask_mod = mask_mod if mask_mod is not None else noop_mask
|
||||
if seq_lengths is None:
|
||||
q_length = kv_indices.shape[-2] * BLOCK_SIZE[0]
|
||||
kv_length = q_indices.shape[-2] * BLOCK_SIZE[1]
|
||||
seq_lengths = (q_length, kv_length)
|
||||
|
||||
return cls(
|
||||
seq_lengths=seq_lengths,
|
||||
kv_num_blocks=kv_num_blocks,
|
||||
kv_indices=kv_indices,
|
||||
full_kv_num_blocks=full_kv_num_blocks,
|
||||
|
|
@ -389,15 +380,11 @@ class BlockMask:
|
|||
Args:
|
||||
flatten (bool): If True, it will flatten the tuple of (KV_BLOCK_SIZE, Q_BLOCK_SIZE)
|
||||
"""
|
||||
if flatten:
|
||||
block_size = (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) # type: ignore[assignment]
|
||||
seq_lengths = (self.seq_lengths[0], self.seq_lengths[1]) # type: ignore[assignment]
|
||||
else:
|
||||
block_size = (self.BLOCK_SIZE,) # type: ignore[assignment]
|
||||
seq_lengths = (self.seq_lengths,) # type: ignore[assignment]
|
||||
block_size = (
|
||||
(self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) if flatten else (self.BLOCK_SIZE,)
|
||||
)
|
||||
|
||||
return (
|
||||
*seq_lengths,
|
||||
self.kv_num_blocks,
|
||||
self.kv_indices,
|
||||
self.full_kv_num_blocks,
|
||||
|
|
@ -410,11 +397,6 @@ class BlockMask:
|
|||
self.mask_mod,
|
||||
)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
*batch_dims, _, _ = self.kv_indices.shape
|
||||
return tuple(batch_dims) + self.seq_lengths
|
||||
|
||||
def __str__(self):
|
||||
s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n"
|
||||
mask_str = self.to_string().strip()
|
||||
|
|
@ -475,7 +457,6 @@ class BlockMask:
|
|||
new_full_kv_indices,
|
||||
BLOCK_SIZE=self.BLOCK_SIZE,
|
||||
mask_mod=None,
|
||||
seq_lengths=self.seq_lengths,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
|
@ -528,6 +509,14 @@ class BlockMask:
|
|||
self.mask_mod,
|
||||
)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""Returns the shape of the mask."""
|
||||
*batch_dims, q_length, _ = self.kv_indices.shape
|
||||
q_length = self.kv_indices.shape[-2] * self.BLOCK_SIZE[0]
|
||||
kv_length = self.kv_indices.shape[-1] * self.BLOCK_SIZE[1]
|
||||
return tuple(batch_dims + [q_length, kv_length])
|
||||
|
||||
def numel(self):
|
||||
"""Returns the number of elements (not accounting for sparsity) in the mask."""
|
||||
shape = self.shape
|
||||
|
|
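With this revert, BlockMask.shape is again derived from the kv_indices block grid and BLOCK_SIZE rather than from explicitly stored sequence lengths. A hedged sketch of what that computation yields; the from_kv_blocks arguments are illustrative, and the int32 dtype follows the convention used elsewhere in this diff:

import torch
from torch.nn.attention.flex_attention import BlockMask

# A dense 16 x 16 grid of 128 x 128 blocks: every query block sees all 16 kv blocks.
kv_num_blocks = torch.full((1, 1, 16), 16, dtype=torch.int32)
kv_indices = (
    torch.arange(16, dtype=torch.int32).view(1, 1, 1, 16).expand(1, 1, 16, 16).contiguous()
)

mask = BlockMask.from_kv_blocks(kv_num_blocks, kv_indices, BLOCK_SIZE=128)
# shape = batch dims + (kv_indices.shape[-2] * BLOCK_SIZE[0], kv_indices.shape[-1] * BLOCK_SIZE[1])
print(mask.shape)  # (1, 1, 2048, 2048)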
@ -750,7 +739,6 @@ def _convert_block_mask_to_mask(
|
|||
def _create_sparse_block_from_block_mask(
|
||||
block_mask: Tuple[Tensor, Optional[Tensor]],
|
||||
mask_mod: Optional[Callable],
|
||||
seq_lengths: Tuple[int, int],
|
||||
Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
|
||||
KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
|
||||
) -> BlockMask:
|
||||
|
|
@ -769,7 +757,6 @@ def _create_sparse_block_from_block_mask(
|
|||
full_bm[1],
|
||||
BLOCK_SIZE=(Q_BLOCK_SIZE, KV_BLOCK_SIZE),
|
||||
mask_mod=mask_mod,
|
||||
seq_lengths=seq_lengths,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -891,11 +878,7 @@ def create_block_mask(
|
|||
separate_full_blocks=True,
|
||||
)
|
||||
block_mask = _create_sparse_block_from_block_mask(
|
||||
(partial_block_mask, full_block_mask),
|
||||
mask_mod,
|
||||
(Q_LEN, KV_LEN),
|
||||
Q_BLOCK_SIZE,
|
||||
KV_BLOCK_SIZE,
|
||||
(partial_block_mask, full_block_mask), mask_mod, Q_BLOCK_SIZE, KV_BLOCK_SIZE
|
||||
)
|
||||
return block_mask
|
||||
|
||||
|
|
@ -911,7 +894,6 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
|
|||
kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
|
||||
kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
|
||||
BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE,
|
||||
seq_lengths=(1, 1),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1255,31 +1237,29 @@ def flex_attention(
|
|||
|
||||
if block_mask is None:
|
||||
block_mask = _create_empty_block_mask(query, key)
|
||||
|
||||
if (
|
||||
block_mask.BLOCK_SIZE[0] == _LARGE_SPARSE_BLOCK_SIZE
|
||||
and block_mask.BLOCK_SIZE[1] == _LARGE_SPARSE_BLOCK_SIZE
|
||||
elif (
|
||||
not query.is_nested
|
||||
and (query.requires_grad or key.requires_grad or value.requires_grad)
|
||||
and (
|
||||
query.size(-2)
|
||||
< block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
|
||||
or key.size(-2) < block_mask.kv_indices.size(-1) * block_mask.BLOCK_SIZE[1]
|
||||
)
|
||||
):
|
||||
# This corresponds to the case where we essentially have a "no-op" block mask.
|
||||
pass
|
||||
else:
|
||||
block_mask_q_len = block_mask.shape[-2]
|
||||
block_mask_kv_len = block_mask.shape[-1]
|
||||
if query.size(-2) > block_mask_q_len or key.size(-2) > block_mask_kv_len:
|
||||
raise ValueError(
|
||||
f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. "
|
||||
"As the block mask was created for a smaller length than you're using it for, you likely need to create a new block mask."
|
||||
)
|
||||
elif (
|
||||
query.size(-2) < block_mask_q_len and key.size(-2) <= block_mask_kv_len
|
||||
) or (query.size(-2) <= block_mask_q_len and key.size(-2) < block_mask_kv_len):
|
||||
raise ValueError(
|
||||
f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. "
|
||||
"As the block mask was created for a larger length than you're using it for, you can either 1. create a new block mask with the correct length, or 2. 'adjust' the existing block mask to the correct length by calling block_mask._adjust(q_len, kv_len). This essentially 'crops' the block mask to the upper left corner, which does not work for all mask_mods!"
|
||||
)
|
||||
assert query.size(-2) == block_mask_q_len
|
||||
assert key.size(-2) == block_mask_kv_len
|
||||
|
||||
new_q_len = _round_up_to_multiple(query.size(-2), block_mask.BLOCK_SIZE[0])
|
||||
new_kv_len = _round_up_to_multiple(key.size(-2), block_mask.BLOCK_SIZE[1])
|
||||
block_mask = block_mask._adjust(new_q_len, new_kv_len)
|
||||
elif query.is_nested and (
|
||||
block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
|
||||
!= _round_up_to_multiple(
|
||||
query._values.size(query._ragged_idx - 1), block_mask.BLOCK_SIZE[0] # type: ignore[attr-defined]
|
||||
)
|
||||
):
|
||||
# TODO: Maybe we want to auto-adjust for this case as well?
|
||||
raise RuntimeError(
|
||||
f"block_mask of shape {block_mask.shape} is not compatible with nested tensor input "
|
||||
f"with total sequence length of {query._values.size(query._ragged_idx - 1)}" # type: ignore[attr-defined]
|
||||
)
|
||||
if scale is None:
|
||||
scale = 1.0 / math.sqrt(query.size(-1))
|
||||
|
||||
|
|
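Taken together, the restored wrapper logic above amounts to: a trivially large "no-op" mask is passed through, a mask that is larger than a grad-requiring input is cropped via BlockMask._adjust after rounding the input lengths up to the block size, and at this point only an incompatible nested-tensor input raises (over-long inputs are rejected later by the lowering asserts). A compact sketch of the rounding step with illustrative numbers; the helper below only mirrors what _round_up_to_multiple is expected to compute:

def round_up_to_multiple(x, multiple):
    # round x up to the next multiple (assumed equivalent to _round_up_to_multiple)
    return ((x + multiple - 1) // multiple) * multiple

BLOCK_SIZE = (128, 128)
q_len, kv_len = 1000, 1000  # actual input lengths
new_q_len = round_up_to_multiple(q_len, BLOCK_SIZE[0])    # 1024
new_kv_len = round_up_to_multiple(kv_len, BLOCK_SIZE[1])  # 1024
# block_mask = block_mask._adjust(new_q_len, new_kv_len) would then crop a larger
# mask to these rounded lengths before the kernel runs.
print(new_q_len, new_kv_len)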
|
|||