Make Q Indices optional (#157997)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157997
Approved by: https://github.com/BoyuanFeng, https://github.com/Chillee
drisspg 2025-07-10 18:56:07 -07:00 committed by PyTorch MergeBot
parent 22f3347fd9
commit 8c928372b3
4 changed files with 173 additions and 58 deletions
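
In short: BlockMask.from_kv_blocks gains a compute_q_blocks flag (default True). With compute_q_blocks=False the q-side metadata (q_num_blocks, q_indices, full_q_num_blocks, full_q_indices) is left as None, which is all the forward pass needs; requesting gradients now raises an explicit RuntimeError. A minimal sketch of the new surface, mirroring the tests below (shapes, block counts, and variable names are illustrative, not part of the diff):

import torch
from torch.nn.attention.flex_attention import BlockMask, flex_attention

device = "cuda"
# One (B=1, H=1) row of block-sparse metadata: 4 KV blocks of size 128.
kv_num_blocks = torch.tensor([[[4]]], dtype=torch.int32, device=device)
kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)

# Skip the q-side transpose; q_num_blocks/q_indices stay None.
mask = BlockMask.from_kv_blocks(kv_num_blocks, kv_indices, compute_q_blocks=False)

q = torch.randn(1, 1, 128, 64, device=device, dtype=torch.float16)
k = torch.randn(1, 1, 512, 64, device=device, dtype=torch.float16)
v = torch.randn(1, 1, 512, 64, device=device, dtype=torch.float16)

flex = torch.compile(flex_attention, fullgraph=True)
out = flex(q, k, v, block_mask=mask)  # forward-only: fine without q_indices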


@@ -4608,44 +4608,6 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         seq_lengths=(1, 1),
     )
 
-    @supported_platform
-    @common_utils.parametrize("compile", [False, True])
-    def test_no_q_info(self, device, compile: bool):
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
-        # manually set q_num_blocks and q_indices to None
-        block_mask.q_num_blocks = None
-        block_mask.q_indices = None
-        block_mask.full_q_num_blocks = None
-        block_mask.full_q_indices = None
-
-        mask_mod_sparse_flex = functools.partial(flex_attention, block_mask=block_mask)
-        if compile:
-            mask_mod_sparse_flex = torch.compile(
-                mask_mod_sparse_flex, backend="inductor"
-            )
-        inputs = [
-            torch.randn(
-                2,
-                2,
-                2048,
-                64,
-                device=device,
-                dtype=torch.float16,
-                requires_grad=True,
-            )
-            for _ in range(3)
-        ]
-
-        causal_mask_out = mask_mod_sparse_flex(*inputs)
-        sdpa_mask_out = torch.nn.functional.scaled_dot_product_attention(
-            *inputs, is_causal=True
-        )
-
-        torch.testing.assert_close(causal_mask_out, sdpa_mask_out, atol=5e-3, rtol=0.0)
-
     @supported_platform
     def test_doc_mask_clamped_repro(self, device):
         def _offsets_to_doc_ids_tensor(offsets):
@@ -4800,6 +4762,146 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         with self.assertRaisesRegex(ValueError, "block_mask was created for"):
             flex_attention_call(*create_inputs(1024), block_mask=block_mask)
 
+    @supported_platform
+    @common_utils.parametrize("full_indices", [False, True])
+    def test_from_kv_blocks_without_q_computation(self, device, full_indices: bool):
+        (
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
+        ) = self.generate_test_inputs(full_indices, device=device)
+
+        block_mask = BlockMask.from_kv_blocks(
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
+            compute_q_blocks=False,
+        )
+
+        self.assertIsInstance(block_mask, BlockMask)
+        self.assertEqual(block_mask.kv_num_blocks, kv_num_blocks)
+        self.assertEqual(block_mask.kv_indices, kv_indices)
+        self.assertIsNone(block_mask.q_num_blocks)
+        self.assertIsNone(block_mask.q_indices)
+        self.assertIsNone(block_mask.full_q_num_blocks)
+        self.assertIsNone(block_mask.full_q_indices)
+
+        if full_indices:
+            self.assertEqual(block_mask.full_kv_num_blocks, full_kv_num_blocks)
+            self.assertEqual(block_mask.full_kv_indices, full_kv_indices)
+        else:
+            self.assertIsNone(block_mask.full_kv_num_blocks)
+            self.assertIsNone(block_mask.full_kv_indices)
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backward_error_with_none_q_indices(self, device):
+        N_BLOCKS = 4
+        B, H, S, D = 1, 1, 128, 64
+        S_KV = N_BLOCKS * S
+
+        kv_num_blocks = torch.tensor([[[N_BLOCKS]]], dtype=torch.int32, device=device)
+        kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)
+        block_mask = BlockMask.from_kv_blocks(
+            kv_num_blocks, kv_indices, compute_q_blocks=False
+        )
+
+        q = torch.randn(
+            B, H, S, D, dtype=torch.float16, device=device, requires_grad=True
+        )
+        k = torch.randn(
+            B, H, S_KV, D, dtype=torch.float16, device=device, requires_grad=True
+        )
+        v = torch.randn(
+            B, H, S_KV, D, dtype=torch.float16, device=device, requires_grad=True
+        )
+
+        flex_compile = torch.compile(flex_attention, fullgraph=True)
+
+        with torch.no_grad():
+            out_no_grad = flex_compile(q, k, v, block_mask=block_mask)
+        self.assertEqual(out_no_grad.shape, (B, H, S, D))
+
+        # Forward pass with grad enabled should error immediately
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "BlockMask q_indices is None. Backward pass requires q_indices to be computed. "
+            "Please create the BlockMask with compute_q_blocks=True",
+        ):
+            flex_compile(q, k, v, block_mask=block_mask)
+
+    @supported_platform
+    @skip_on_cpu
+    def test_forward_pass_with_none_q_indices(self, device):
+        N_BLOCKS = 4
+        B, H, S, D = 1, 1, 128, 64
+        S_KV = N_BLOCKS * S
+
+        kv_num_blocks = torch.tensor([[[N_BLOCKS]]], dtype=torch.int32, device=device)
+        kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)
+        block_mask = BlockMask.from_kv_blocks(
+            kv_num_blocks, kv_indices, compute_q_blocks=False
+        )
+
+        q = torch.randn(
+            B,
+            H,
+            S,
+            D,
+            dtype=torch.float16,
+            device=device,
+        )
+        k = torch.randn(
+            B,
+            H,
+            S_KV,
+            D,
+            dtype=torch.float16,
+            device=device,
+        )
+        v = torch.randn(
+            B,
+            H,
+            S_KV,
+            D,
+            dtype=torch.float16,
+            device=device,
+        )
+
+        flex_compile = torch.compile(flex_attention, fullgraph=True)
+        out = flex_compile(q, k, v, block_mask=block_mask)
+
+        self.assertEqual(out.shape, (B, H, S, D))
+        self.assertIsInstance(out, torch.Tensor)
+        self.assertEqual(out.dtype, torch.float16)
+
+    @supported_platform
+    def test_block_mask_operations_with_none_q_indices(self, device):
+        kv_num_blocks = torch.tensor([[[4]]], dtype=torch.int32, device=device)
+        kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)
+        block_mask = BlockMask.from_kv_blocks(
+            kv_num_blocks, kv_indices, compute_q_blocks=False
+        )
+
+        self.assertEqual(block_mask.shape, (1, 1, 128, 512))
+        self.assertEqual(block_mask.BLOCK_SIZE, (128, 128))
+
+        sliced_mask = block_mask[0]
+        self.assertEqual(sliced_mask.shape, (1, 128, 512))
+        self.assertIsNone(sliced_mask.q_indices)
+        self.assertIsNone(sliced_mask.q_num_blocks)
+
+        # Test device movement
+        if device != "cpu":
+            cpu_mask = block_mask.to("cpu")
+            self.assertEqual(cpu_mask.kv_num_blocks.device.type, "cpu")
+            self.assertIsNone(cpu_mask.q_indices)
+
 
 @large_tensor_test_class("2GB", device="cuda")
 class TestPagedAttention(InductorTestCase):


@@ -134,6 +134,7 @@ class FlexAttentionBackwardHOP(HigherOrderOperator):
         torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
     ]:
         validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
+
         return super().__call__(
             query,
             key,
@@ -770,6 +771,11 @@ def flex_attention_autograd(
         for t in (query, key, value, *score_mod_other_buffers)
     )
     if torch.is_grad_enabled() and input_requires_grad:
+        if block_mask[7] is None:
+            raise RuntimeError(
+                "BlockMask q_indices is None. Backward pass requires q_indices to be computed. "
+                "Please create the BlockMask with compute_q_blocks=True"
+            )
         example_vals = (
             query.new_zeros((), requires_grad=input_requires_grad),
             query.new_zeros((), dtype=torch.int),
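
The autograd entry point receives the block mask as a flattened tuple in which index 7 holds q_indices, so a q-less mask is rejected up front, before any backward graph is built. A sketch of the failure mode, reusing the q-less mask from the sketch above (hedged; the message text is the one added in this hunk):

q = q.detach().requires_grad_()  # gradients are now requested

try:
    flex(q, k, v, block_mask=mask)
except RuntimeError as err:
    assert "q_indices is None" in str(err)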


@@ -1455,17 +1455,6 @@ def flex_attention(
         num_consumer_groups, num_buffers_warp_spec = 0, 0
 
     for conf in configs:
-        if (
-            SPARSE_KV_BLOCK_SIZE % conf.block_n != 0
-            or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0
-        ):
-            if len(configs) == 1:
-                raise ValueError(
-                    f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We "
-                    f"got Q_BLOCK_SIZE={SPARSE_Q_BLOCK_SIZE} and KV_BLOCK_SIZE={SPARSE_KV_BLOCK_SIZE}."
-                )
-            continue
-
         cur_kernel_options = original_kernel_options.copy()
         # Performance tuning
         # Triton parameters
@@ -1493,6 +1482,20 @@ def flex_attention(
         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
         cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
 
+        if (
+            cur_kernel_options["SPARSE_KV_BLOCK_SIZE"] % cur_kernel_options["BLOCK_N"]
+            != 0
+            or cur_kernel_options["SPARSE_Q_BLOCK_SIZE"] % cur_kernel_options["BLOCK_M"]
+            != 0
+        ):
+            if len(configs) == 1:
+                raise ValueError(
+                    f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We "
+                    f"got Q_BLOCK_SIZE={cur_kernel_options['SPARSE_Q_BLOCK_SIZE']} and "
+                    f"KV_BLOCK_SIZE={cur_kernel_options['SPARSE_KV_BLOCK_SIZE']}."
+                )
+            continue
+
         # ROCm specific kernargs
         for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
             if hasattr(conf, attrib):
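
Note the divisibility check moved from the top of the config loop to after the setdefault calls, so it now validates the final cur_kernel_options, including anything the caller supplied, rather than the precomputed defaults. A hedged sketch of how caller overrides reach this path (the specific values are illustrative):

out = flex(
    q, k, v,
    block_mask=mask,
    # BLOCK_M/BLOCK_N must divide SPARSE_Q_BLOCK_SIZE/SPARSE_KV_BLOCK_SIZE;
    # the check above now runs against these final, post-override values.
    kernel_options={"BLOCK_M": 64, "BLOCK_N": 64},
)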


@@ -292,8 +292,6 @@ class BlockMask:
             raise RuntimeError("BlockMask must have at least 2 dimensions")
         assert kv_num_blocks is not None, "kv_num_blocks must be provided"
         assert kv_indices is not None, "kv_indices must be provided"
-        assert q_num_blocks is not None, "q_num_blocks must be provided"
-        assert q_indices is not None, "q_indices must be provided"
         assert (full_kv_num_blocks is None) == (full_kv_indices is None), (
             "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
         )
@@ -323,6 +321,7 @@ class BlockMask:
         BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
         mask_mod: Optional[_mask_mod_signature] = None,
         seq_lengths: Optional[tuple[int, int]] = None,
+        compute_q_blocks: bool = True,
     ):
         """
         Creates a BlockMask instance from key-value block information.
@@ -350,13 +349,17 @@ class BlockMask:
             )
 
         # Generate q_num_blocks and q_indices
-        q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
-        if full_kv_num_blocks is not None:
-            assert full_kv_indices is not None
-            full_q_num_blocks, full_q_indices = _transpose_ordered(
-                full_kv_num_blocks, full_kv_indices
-            )
+        if compute_q_blocks:
+            q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
+            if full_kv_num_blocks is not None:
+                assert full_kv_indices is not None
+                full_q_num_blocks, full_q_indices = _transpose_ordered(
+                    full_kv_num_blocks, full_kv_indices
+                )
+            else:
+                full_q_num_blocks, full_q_indices = None, None
         else:
+            q_num_blocks, q_indices = None, None
             full_q_num_blocks, full_q_indices = None, None
 
         if isinstance(BLOCK_SIZE, int):
@@ -365,7 +368,7 @@ class BlockMask:
         mask_mod = mask_mod if mask_mod is not None else noop_mask
         if seq_lengths is None:
             q_length = kv_indices.shape[-2] * BLOCK_SIZE[0]
-            kv_length = q_indices.shape[-2] * BLOCK_SIZE[1]
+            kv_length = kv_indices.shape[-1] * BLOCK_SIZE[1]
             seq_lengths = (q_length, kv_length)
 
         return cls(
@@ -481,6 +484,7 @@ class BlockMask:
             BLOCK_SIZE=self.BLOCK_SIZE,
             mask_mod=None,
             seq_lengths=self.seq_lengths,
+            compute_q_blocks=self.q_indices is not None,
         )
 
     def __repr__(self):
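
Finally, masks derived from an existing BlockMask are rebuilt through from_kv_blocks with compute_q_blocks=self.q_indices is not None, so the q-less property survives slicing and device moves. A small sketch, mirroring test_block_mask_operations_with_none_q_indices above:

mask = BlockMask.from_kv_blocks(kv_num_blocks, kv_indices, compute_q_blocks=False)

sliced = mask[0]                 # slicing round-trips through from_kv_blocks
assert sliced.q_indices is None  # q-side metadata is still absent
assert sliced.q_num_blocks is None

cpu_mask = mask.to("cpu")        # device moves also keep q_indices as None
assert cpu_mask.q_indices is None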