Make Q Indices optional (#157997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157997 Approved by: https://github.com/BoyuanFeng, https://github.com/Chillee
parent 22f3347fd9
commit 8c928372b3
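For orientation, a minimal usage sketch of the option this commit adds. It mirrors the new tests below; the CUDA device and shapes are illustrative, and BlockMask / flex_attention are assumed to come from torch.nn.attention.flex_attention:

    import torch
    from torch.nn.attention.flex_attention import BlockMask, flex_attention

    device = "cuda"
    B, H, S, D, N_BLOCKS = 1, 1, 128, 64, 4
    kv_num_blocks = torch.tensor([[[N_BLOCKS]]], dtype=torch.int32, device=device)
    kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)

    # Skip the transpose that derives q_num_blocks/q_indices; the resulting
    # mask has q_indices=None and is intended for forward/inference-only use.
    block_mask = BlockMask.from_kv_blocks(
        kv_num_blocks, kv_indices, compute_q_blocks=False
    )

    q = torch.randn(B, H, S, D, dtype=torch.float16, device=device)
    k = torch.randn(B, H, N_BLOCKS * S, D, dtype=torch.float16, device=device)
    v = torch.randn(B, H, N_BLOCKS * S, D, dtype=torch.float16, device=device)

    flex = torch.compile(flex_attention, fullgraph=True)
    out = flex(q, k, v, block_mask=block_mask)  # forward works without q indices

    # With grad enabled, the same call raises a RuntimeError asking for a
    # BlockMask created with compute_q_blocks=True.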
@@ -4608,44 +4608,6 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
             seq_lengths=(1, 1),
         )
 
-    @supported_platform
-    @common_utils.parametrize("compile", [False, True])
-    def test_no_q_info(self, device, compile: bool):
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
-        # manually set q_num_blocks and q_indices to None
-        block_mask.q_num_blocks = None
-        block_mask.q_indices = None
-        block_mask.full_q_num_blocks = None
-        block_mask.full_q_indices = None
-
-        mask_mod_sparse_flex = functools.partial(flex_attention, block_mask=block_mask)
-        if compile:
-            mask_mod_sparse_flex = torch.compile(
-                mask_mod_sparse_flex, backend="inductor"
-            )
-        inputs = [
-            torch.randn(
-                2,
-                2,
-                2048,
-                64,
-                device=device,
-                dtype=torch.float16,
-                requires_grad=True,
-            )
-            for _ in range(3)
-        ]
-
-        causal_mask_out = mask_mod_sparse_flex(*inputs)
-        sdpa_mask_out = torch.nn.functional.scaled_dot_product_attention(
-            *inputs, is_causal=True
-        )
-
-        torch.testing.assert_close(causal_mask_out, sdpa_mask_out, atol=5e-3, rtol=0.0)
-
     @supported_platform
     def test_doc_mask_clamped_repro(self, device):
         def _offsets_to_doc_ids_tensor(offsets):
@@ -4800,6 +4762,146 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         with self.assertRaisesRegex(ValueError, "block_mask was created for"):
             flex_attention_call(*create_inputs(1024), block_mask=block_mask)
 
+    @supported_platform
+    @common_utils.parametrize("full_indices", [False, True])
+    def test_from_kv_blocks_without_q_computation(self, device, full_indices: bool):
+        (
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
+        ) = self.generate_test_inputs(full_indices, device=device)
+
+        block_mask = BlockMask.from_kv_blocks(
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
+            compute_q_blocks=False,
+        )
+
+        self.assertIsInstance(block_mask, BlockMask)
+        self.assertEqual(block_mask.kv_num_blocks, kv_num_blocks)
+        self.assertEqual(block_mask.kv_indices, kv_indices)
+
+        self.assertIsNone(block_mask.q_num_blocks)
+        self.assertIsNone(block_mask.q_indices)
+        self.assertIsNone(block_mask.full_q_num_blocks)
+        self.assertIsNone(block_mask.full_q_indices)
+
+        if full_indices:
+            self.assertEqual(block_mask.full_kv_num_blocks, full_kv_num_blocks)
+            self.assertEqual(block_mask.full_kv_indices, full_kv_indices)
+        else:
+            self.assertIsNone(block_mask.full_kv_num_blocks)
+            self.assertIsNone(block_mask.full_kv_indices)
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backward_error_with_none_q_indices(self, device):
+        N_BLOCKS = 4
+        B, H, S, D = 1, 1, 128, 64
+        S_KV = N_BLOCKS * S
+
+        kv_num_blocks = torch.tensor([[[N_BLOCKS]]], dtype=torch.int32, device=device)
+        kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)
+
+        block_mask = BlockMask.from_kv_blocks(
+            kv_num_blocks, kv_indices, compute_q_blocks=False
+        )
+
+        q = torch.randn(
+            B, H, S, D, dtype=torch.float16, device=device, requires_grad=True
+        )
+        k = torch.randn(
+            B, H, S_KV, D, dtype=torch.float16, device=device, requires_grad=True
+        )
+        v = torch.randn(
+            B, H, S_KV, D, dtype=torch.float16, device=device, requires_grad=True
+        )
+
+        flex_compile = torch.compile(flex_attention, fullgraph=True)
+
+        with torch.no_grad():
+            out_no_grad = flex_compile(q, k, v, block_mask=block_mask)
+        self.assertEqual(out_no_grad.shape, (B, H, S, D))
+
+        # Forward pass with grad enabled should error immediately
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "BlockMask q_indices is None. Backward pass requires q_indices to be computed. "
+            "Please create the BlockMask with compute_q_blocks=True",
+        ):
+            flex_compile(q, k, v, block_mask=block_mask)
+
+    @supported_platform
+    @skip_on_cpu
+    def test_forward_pass_with_none_q_indices(self, device):
+        N_BLOCKS = 4
+        B, H, S, D = 1, 1, 128, 64
+        S_KV = N_BLOCKS * S
+
+        kv_num_blocks = torch.tensor([[[N_BLOCKS]]], dtype=torch.int32, device=device)
+        kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)
+
+        block_mask = BlockMask.from_kv_blocks(
+            kv_num_blocks, kv_indices, compute_q_blocks=False
+        )
+
+        q = torch.randn(
+            B,
+            H,
+            S,
+            D,
+            dtype=torch.float16,
+            device=device,
+        )
+        k = torch.randn(
+            B,
+            H,
+            S_KV,
+            D,
+            dtype=torch.float16,
+            device=device,
+        )
+        v = torch.randn(
+            B,
+            H,
+            S_KV,
+            D,
+            dtype=torch.float16,
+            device=device,
+        )
+
+        flex_compile = torch.compile(flex_attention, fullgraph=True)
+        out = flex_compile(q, k, v, block_mask=block_mask)
+
+        self.assertEqual(out.shape, (B, H, S, D))
+        self.assertIsInstance(out, torch.Tensor)
+        self.assertEqual(out.dtype, torch.float16)
+
+    @supported_platform
+    def test_block_mask_operations_with_none_q_indices(self, device):
+        kv_num_blocks = torch.tensor([[[4]]], dtype=torch.int32, device=device)
+        kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32, device=device)
+
+        block_mask = BlockMask.from_kv_blocks(
+            kv_num_blocks, kv_indices, compute_q_blocks=False
+        )
+        self.assertEqual(block_mask.shape, (1, 1, 128, 512))
+        self.assertEqual(block_mask.BLOCK_SIZE, (128, 128))
+
+        sliced_mask = block_mask[0]
+        self.assertEqual(sliced_mask.shape, (1, 128, 512))
+        self.assertIsNone(sliced_mask.q_indices)
+        self.assertIsNone(sliced_mask.q_num_blocks)
+
+        # Test device movement
+        if device != "cpu":
+            cpu_mask = block_mask.to("cpu")
+            self.assertEqual(cpu_mask.kv_num_blocks.device.type, "cpu")
+            self.assertIsNone(cpu_mask.q_indices)
+
 
 @large_tensor_test_class("2GB", device="cuda")
 class TestPagedAttention(InductorTestCase):
@@ -134,6 +134,7 @@ class FlexAttentionBackwardHOP(HigherOrderOperator):
         torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
     ]:
         validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
+
         return super().__call__(
             query,
             key,
@@ -770,6 +771,11 @@ def flex_attention_autograd(
         for t in (query, key, value, *score_mod_other_buffers)
     )
     if torch.is_grad_enabled() and input_requires_grad:
+        if block_mask[7] is None:
+            raise RuntimeError(
+                "BlockMask q_indices is None. Backward pass requires q_indices to be computed. "
+                "Please create the BlockMask with compute_q_blocks=True"
+            )
         example_vals = (
             query.new_zeros((), requires_grad=input_requires_grad),
             query.new_zeros((), dtype=torch.int),
@@ -1455,17 +1455,6 @@ def flex_attention(
     num_consumer_groups, num_buffers_warp_spec = 0, 0
 
     for conf in configs:
-        if (
-            SPARSE_KV_BLOCK_SIZE % conf.block_n != 0
-            or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0
-        ):
-            if len(configs) == 1:
-                raise ValueError(
-                    f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We "
-                    f"got Q_BLOCK_SIZE={SPARSE_Q_BLOCK_SIZE} and KV_BLOCK_SIZE={SPARSE_KV_BLOCK_SIZE}."
-                )
-            continue
-
         cur_kernel_options = original_kernel_options.copy()
         # Performance tuning
         # Triton parameters
@@ -1493,6 +1482,20 @@ def flex_attention(
         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
         cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
 
+        if (
+            cur_kernel_options["SPARSE_KV_BLOCK_SIZE"] % cur_kernel_options["BLOCK_N"]
+            != 0
+            or cur_kernel_options["SPARSE_Q_BLOCK_SIZE"] % cur_kernel_options["BLOCK_M"]
+            != 0
+        ):
+            if len(configs) == 1:
+                raise ValueError(
+                    f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We "
+                    f"got Q_BLOCK_SIZE={cur_kernel_options['SPARSE_Q_BLOCK_SIZE']} and "
+                    f"KV_BLOCK_SIZE={cur_kernel_options['SPARSE_KV_BLOCK_SIZE']}."
+                )
+            continue
+
         # ROCm specific kernargs
         for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
             if hasattr(conf, attrib):
@@ -292,8 +292,6 @@ class BlockMask:
             raise RuntimeError("BlockMask must have at least 2 dimensions")
         assert kv_num_blocks is not None, "kv_num_blocks must be provided"
         assert kv_indices is not None, "kv_indices must be provided"
-        assert q_num_blocks is not None, "q_num_blocks must be provided"
-        assert q_indices is not None, "q_indices must be provided"
         assert (full_kv_num_blocks is None) == (full_kv_indices is None), (
             "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
         )
@@ -323,6 +321,7 @@ class BlockMask:
         BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
         mask_mod: Optional[_mask_mod_signature] = None,
         seq_lengths: Optional[tuple[int, int]] = None,
+        compute_q_blocks: bool = True,
     ):
         """
         Creates a BlockMask instance from key-value block information.
@@ -350,13 +349,17 @@ class BlockMask:
         )
 
         # Generate q_num_blocks and q_indices
-        q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
-        if full_kv_num_blocks is not None:
-            assert full_kv_indices is not None
-            full_q_num_blocks, full_q_indices = _transpose_ordered(
-                full_kv_num_blocks, full_kv_indices
-            )
+        if compute_q_blocks:
+            q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
+            if full_kv_num_blocks is not None:
+                assert full_kv_indices is not None
+                full_q_num_blocks, full_q_indices = _transpose_ordered(
+                    full_kv_num_blocks, full_kv_indices
+                )
+            else:
+                full_q_num_blocks, full_q_indices = None, None
         else:
+            q_num_blocks, q_indices = None, None
             full_q_num_blocks, full_q_indices = None, None
 
         if isinstance(BLOCK_SIZE, int):
@@ -365,7 +368,7 @@ class BlockMask:
         mask_mod = mask_mod if mask_mod is not None else noop_mask
         if seq_lengths is None:
             q_length = kv_indices.shape[-2] * BLOCK_SIZE[0]
-            kv_length = q_indices.shape[-2] * BLOCK_SIZE[1]
+            kv_length = kv_indices.shape[-1] * BLOCK_SIZE[1]
             seq_lengths = (q_length, kv_length)
 
         return cls(
@@ -481,6 +484,7 @@ class BlockMask:
             BLOCK_SIZE=self.BLOCK_SIZE,
             mask_mod=None,
             seq_lengths=self.seq_lengths,
+            compute_q_blocks=self.q_indices is not None,
         )
 
     def __repr__(self):
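A worked example of the seq_lengths inference change in from_kv_blocks (the -365,7 hunk above): since q_indices may now be None, the KV length is read directly from kv_indices. This sketch assumes the default (128, 128) BLOCK_SIZE and the kv_indices used in the tests:

    import torch

    kv_indices = torch.tensor([[[[0, 1, 2, 3]]]], dtype=torch.int32)
    BLOCK_SIZE = (128, 128)

    q_length = kv_indices.shape[-2] * BLOCK_SIZE[0]   # 1 * 128 = 128
    kv_length = kv_indices.shape[-1] * BLOCK_SIZE[1]  # 4 * 128 = 512
    # Matches the BlockMask.shape of (1, 1, 128, 512) asserted in
    # test_block_mask_operations_with_none_q_indices above.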