mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make Q Indices optional (#157997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157997 Approved by: https://github.com/BoyuanFeng, https://github.com/Chillee
This commit is contained in:
parent
22f3347fd9
commit
8c928372b3
|
|
@ -4608,44 +4608,6 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
|
|||
seq_lengths=(1, 1),
|
||||
)
|
||||
|
||||
@supported_platform
|
||||
@common_utils.parametrize("compile", [False, True])
|
||||
def test_no_q_info(self, device, compile: bool):
|
||||
def causal_mask(b, h, q_idx, kv_idx):
|
||||
return q_idx >= kv_idx
|
||||
|
||||
block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
|
||||
# manually set q_num_blocks and q_indices to None
|
||||
block_mask.q_num_blocks = None
|
||||
block_mask.q_indices = None
|
||||
block_mask.full_q_num_blocks = None
|
||||
block_mask.full_q_indices = None
|
||||
|
||||
mask_mod_sparse_flex = functools.partial(flex_attention, block_mask=block_mask)
|
||||
if compile:
|
||||
mask_mod_sparse_flex = torch.compile(
|
||||
mask_mod_sparse_flex, backend="inductor"
|
||||
)
|
||||
inputs = [
|
||||
torch.randn(
|
||||
2,
|
||||
2,
|
||||
2048,
|
||||
64,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
requires_grad=True,
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
causal_mask_out = mask_mod_sparse_flex(*inputs)
|
||||
sdpa_mask_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
*inputs, is_causal=True
|
||||
)
|
||||
|
||||
torch.testing.assert_close(causal_mask_out, sdpa_mask_out, atol=5e-3, rtol=0.0)
|
||||
|
||||
@supported_platform
|
||||
def test_doc_mask_clamped_repro(self, device):
|
||||
def _offsets_to_doc_ids_tensor(offsets):
|
||||
|
|
@ -4800,6 +4762,146 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
|
|||
with self.assertRaisesRegex(ValueError, "block_mask was created for"):
|
||||
flex_attention_call(*create_inputs(1024), block_mask=block_mask)
|
||||
|
||||
@supported_platform
|
||||
@common_utils.parametrize("full_indices", [False, True])
|
||||
def test_from_kv_blocks_without_q_computation(self, device, full_indices: bool):
|
||||
(
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks,
|
||||
full_kv_indices,
|
||||
) = self.generate_test_inputs(full_indices, device=device)
|
||||
|
||||
block_mask = BlockMask.from_kv_blocks(
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks,
|
||||
full_kv_indices,
|
||||
compute_q_blocks=False,
|
||||
)
|
||||
|
||||
self.assertIsInstance(block_mask, BlockMask)
|
||||
self.assertEqual(block_mask.kv_num_blocks, kv_num_blocks)
|
||||
self.assertEqual(block_mask.kv_indices, kv_indices)
|
||||
|
||||
self.assertIsNone(block_mask.q_num_blocks)
|
||||
self.assertIsNone(block_mask.q_indices)
|
||||
self.assertIsNone(block_mask.full_q_num_blocks)
|
||||
self.assertIsNone(block_mask.full_q_indices)
|
||||
|
||||
if full_indices:
|
||||
self.assertEqual(block_mask.full_kv_num_blocks, full_kv_num_blocks)
|
||||
self.assertEqual(block_mask.full_kv_indices, full_kv_indices)
|
||||
else:
|
||||
self.assertIsNone(block_mask.full_kv_num_blocks)
|
||||
self.assertIsNone(block_mask.full_kv_indices)
|
||||
|
||||
@supported_platform
|
||||
@skip_on_cpu
|
||||
def test_backward_error_with_none_q_indices(self, device):
|
||||
N_BLOCKS = 4
|
||||
B, H, S, D = 1, 1, 128, 64
|
||||
S_KV = N_BLOCKS * S
|
||||
|
||||
kv_num_blocks = torch.tensor([[[N_BLOCKS]]], dtype=torch.int32, device=device)
|
||||
kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)
|
||||
|
||||
block_mask = BlockMask.from_kv_blocks(
|
||||
kv_num_blocks, kv_indices, compute_q_blocks=False
|
||||
)
|
||||
|
||||
q = torch.randn(
|
||||
B, H, S, D, dtype=torch.float16, device=device, requires_grad=True
|
||||
)
|
||||
k = torch.randn(
|
||||
B, H, S_KV, D, dtype=torch.float16, device=device, requires_grad=True
|
||||
)
|
||||
v = torch.randn(
|
||||
B, H, S_KV, D, dtype=torch.float16, device=device, requires_grad=True
|
||||
)
|
||||
|
||||
flex_compile = torch.compile(flex_attention, fullgraph=True)
|
||||
|
||||
with torch.no_grad():
|
||||
out_no_grad = flex_compile(q, k, v, block_mask=block_mask)
|
||||
self.assertEqual(out_no_grad.shape, (B, H, S, D))
|
||||
|
||||
# Forward pass with grad enabled should error immediately
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"BlockMask q_indices is None. Backward pass requires q_indices to be computed. "
|
||||
"Please create the BlockMask with compute_q_blocks=True",
|
||||
):
|
||||
flex_compile(q, k, v, block_mask=block_mask)
|
||||
|
||||
@supported_platform
|
||||
@skip_on_cpu
|
||||
def test_forward_pass_with_none_q_indices(self, device):
|
||||
N_BLOCKS = 4
|
||||
B, H, S, D = 1, 1, 128, 64
|
||||
S_KV = N_BLOCKS * S
|
||||
|
||||
kv_num_blocks = torch.tensor([[[N_BLOCKS]]], dtype=torch.int32, device=device)
|
||||
kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)
|
||||
|
||||
block_mask = BlockMask.from_kv_blocks(
|
||||
kv_num_blocks, kv_indices, compute_q_blocks=False
|
||||
)
|
||||
|
||||
q = torch.randn(
|
||||
B,
|
||||
H,
|
||||
S,
|
||||
D,
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
)
|
||||
k = torch.randn(
|
||||
B,
|
||||
H,
|
||||
S_KV,
|
||||
D,
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
)
|
||||
v = torch.randn(
|
||||
B,
|
||||
H,
|
||||
S_KV,
|
||||
D,
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
flex_compile = torch.compile(flex_attention, fullgraph=True)
|
||||
out = flex_compile(q, k, v, block_mask=block_mask)
|
||||
|
||||
self.assertEqual(out.shape, (B, H, S, D))
|
||||
self.assertIsInstance(out, torch.Tensor)
|
||||
self.assertEqual(out.dtype, torch.float16)
|
||||
|
||||
@supported_platform
|
||||
def test_block_mask_operations_with_none_q_indices(self, device):
|
||||
kv_num_blocks = torch.tensor([[[4]]], dtype=torch.int32, device=device)
|
||||
kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)
|
||||
|
||||
block_mask = BlockMask.from_kv_blocks(
|
||||
kv_num_blocks, kv_indices, compute_q_blocks=False
|
||||
)
|
||||
self.assertEqual(block_mask.shape, (1, 1, 128, 512))
|
||||
self.assertEqual(block_mask.BLOCK_SIZE, (128, 128))
|
||||
|
||||
sliced_mask = block_mask[0]
|
||||
self.assertEqual(sliced_mask.shape, (1, 128, 512))
|
||||
self.assertIsNone(sliced_mask.q_indices)
|
||||
self.assertIsNone(sliced_mask.q_num_blocks)
|
||||
|
||||
# Test device movement
|
||||
if device != "cpu":
|
||||
cpu_mask = block_mask.to("cpu")
|
||||
self.assertEqual(cpu_mask.kv_num_blocks.device.type, "cpu")
|
||||
self.assertIsNone(cpu_mask.q_indices)
|
||||
|
||||
|
||||
@large_tensor_test_class("2GB", device="cuda")
|
||||
class TestPagedAttention(InductorTestCase):
|
||||
|
|
|
|||
|
|
@ -134,6 +134,7 @@ class FlexAttentionBackwardHOP(HigherOrderOperator):
|
|||
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
|
||||
]:
|
||||
validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
|
||||
|
||||
return super().__call__(
|
||||
query,
|
||||
key,
|
||||
|
|
@ -770,6 +771,11 @@ def flex_attention_autograd(
|
|||
for t in (query, key, value, *score_mod_other_buffers)
|
||||
)
|
||||
if torch.is_grad_enabled() and input_requires_grad:
|
||||
if block_mask[7] is None:
|
||||
raise RuntimeError(
|
||||
"BlockMask q_indices is None. Backward pass requires q_indices to be computed. "
|
||||
"Please create the BlockMask with compute_q_blocks=True"
|
||||
)
|
||||
example_vals = (
|
||||
query.new_zeros((), requires_grad=input_requires_grad),
|
||||
query.new_zeros((), dtype=torch.int),
|
||||
|
|
|
|||
|
|
@ -1455,17 +1455,6 @@ def flex_attention(
|
|||
num_consumer_groups, num_buffers_warp_spec = 0, 0
|
||||
|
||||
for conf in configs:
|
||||
if (
|
||||
SPARSE_KV_BLOCK_SIZE % conf.block_n != 0
|
||||
or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0
|
||||
):
|
||||
if len(configs) == 1:
|
||||
raise ValueError(
|
||||
f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We "
|
||||
f"got Q_BLOCK_SIZE={SPARSE_Q_BLOCK_SIZE} and KV_BLOCK_SIZE={SPARSE_KV_BLOCK_SIZE}."
|
||||
)
|
||||
continue
|
||||
|
||||
cur_kernel_options = original_kernel_options.copy()
|
||||
# Performance tuning
|
||||
# Triton parameters
|
||||
|
|
@ -1493,6 +1482,20 @@ def flex_attention(
|
|||
cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
|
||||
cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
|
||||
|
||||
if (
|
||||
cur_kernel_options["SPARSE_KV_BLOCK_SIZE"] % cur_kernel_options["BLOCK_N"]
|
||||
!= 0
|
||||
or cur_kernel_options["SPARSE_Q_BLOCK_SIZE"] % cur_kernel_options["BLOCK_M"]
|
||||
!= 0
|
||||
):
|
||||
if len(configs) == 1:
|
||||
raise ValueError(
|
||||
f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We "
|
||||
f"got Q_BLOCK_SIZE={cur_kernel_options['SPARSE_Q_BLOCK_SIZE']} and "
|
||||
f"KV_BLOCK_SIZE={cur_kernel_options['SPARSE_KV_BLOCK_SIZE']}."
|
||||
)
|
||||
continue
|
||||
|
||||
# ROCm specific kernargs
|
||||
for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
|
||||
if hasattr(conf, attrib):
|
||||
|
|
|
|||
|
|
@ -292,8 +292,6 @@ class BlockMask:
|
|||
raise RuntimeError("BlockMask must have at least 2 dimensions")
|
||||
assert kv_num_blocks is not None, "kv_num_blocks must be provided"
|
||||
assert kv_indices is not None, "kv_indices must be provided"
|
||||
assert q_num_blocks is not None, "q_num_blocks must be provided"
|
||||
assert q_indices is not None, "q_indices must be provided"
|
||||
assert (full_kv_num_blocks is None) == (full_kv_indices is None), (
|
||||
"full_kv_num_blocks and full_kv_indices must be both provided or omitted"
|
||||
)
|
||||
|
|
@ -323,6 +321,7 @@ class BlockMask:
|
|||
BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
|
||||
mask_mod: Optional[_mask_mod_signature] = None,
|
||||
seq_lengths: Optional[tuple[int, int]] = None,
|
||||
compute_q_blocks: bool = True,
|
||||
):
|
||||
"""
|
||||
Creates a BlockMask instance from key-value block information.
|
||||
|
|
@ -350,13 +349,17 @@ class BlockMask:
|
|||
)
|
||||
|
||||
# Generate q_num_blocks and q_indices
|
||||
q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
|
||||
if full_kv_num_blocks is not None:
|
||||
assert full_kv_indices is not None
|
||||
full_q_num_blocks, full_q_indices = _transpose_ordered(
|
||||
full_kv_num_blocks, full_kv_indices
|
||||
)
|
||||
if compute_q_blocks:
|
||||
q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
|
||||
if full_kv_num_blocks is not None:
|
||||
assert full_kv_indices is not None
|
||||
full_q_num_blocks, full_q_indices = _transpose_ordered(
|
||||
full_kv_num_blocks, full_kv_indices
|
||||
)
|
||||
else:
|
||||
full_q_num_blocks, full_q_indices = None, None
|
||||
else:
|
||||
q_num_blocks, q_indices = None, None
|
||||
full_q_num_blocks, full_q_indices = None, None
|
||||
|
||||
if isinstance(BLOCK_SIZE, int):
|
||||
|
|
@ -365,7 +368,7 @@ class BlockMask:
|
|||
mask_mod = mask_mod if mask_mod is not None else noop_mask
|
||||
if seq_lengths is None:
|
||||
q_length = kv_indices.shape[-2] * BLOCK_SIZE[0]
|
||||
kv_length = q_indices.shape[-2] * BLOCK_SIZE[1]
|
||||
kv_length = kv_indices.shape[-1] * BLOCK_SIZE[1]
|
||||
seq_lengths = (q_length, kv_length)
|
||||
|
||||
return cls(
|
||||
|
|
@ -481,6 +484,7 @@ class BlockMask:
|
|||
BLOCK_SIZE=self.BLOCK_SIZE,
|
||||
mask_mod=None,
|
||||
seq_lengths=self.seq_lengths,
|
||||
compute_q_blocks=self.q_indices is not None,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user