[FlexFlash] Wire up mask_mod + blockmask to flash impl (#166359)
I have some local changes that I need to push to flash first: https://github.com/Dao-AILab/flash-attention/pull/1970

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166359
Approved by: https://github.com/v0i0
parent 2699f5410b
commit 51667435f5
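
Before the diff, an illustrative sketch (not part of this commit) of the usage the change enables: a compiled flex_attention call that supplies a BlockMask while forcing the flash backend through kernel_options, mirroring the new tests below. The shapes and the head_dim of 64 are assumptions made for the sketch.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Assumed shapes (batch=2, heads=4, seq_len=512, head_dim=64), matching the tests below.
q, k, v = (
    torch.randn(2, 4, 512, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)
block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device="cuda")

compiled = torch.compile(flex_attention)
# Before this change, force_flash with a BlockMask raised a RuntimeError (see the
# test removed below); now the mask_mod and block-sparse metadata reach the flash kernel.
out = compiled(q, k, v, block_mask=block_mask, kernel_options={"force_flash": True})
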
@@ -122,16 +122,52 @@ def cuda_kernel_profiler(kernel_pattern="flash_attncute"):
         result["found"] = any(kernel_pattern in name for name in kernel_names)


-def flash_vs_triton(q, k, v, score_mod=None, rtol=5e-3, atol=5e-3):
+def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2):
     compiled_fn = torch.compile(flex_attention)
+
+    out_ref_fp32 = flex_attention(
+        q.to(torch.float32),
+        k.to(torch.float32),
+        v.to(torch.float32),
+        score_mod=score_mod,
+        block_mask=block_mask,
+    ).to(q.dtype)
+
     out_flash = compiled_fn(
-        q, k, v, score_mod=score_mod, kernel_options={"force_flash": True}
+        q,
+        k,
+        v,
+        score_mod=score_mod,
+        block_mask=block_mask,
+        kernel_options={"force_flash": True},
     )
-    out_no_flash = compiled_fn(
-        q, k, v, score_mod=score_mod, kernel_options={"force_flash": False}
+    out_triton = compiled_fn(
+        q,
+        k,
+        v,
+        score_mod=score_mod,
+        block_mask=block_mask,
+        kernel_options={"force_flash": False},
     )
-    torch.testing.assert_close(out_flash, out_no_flash, rtol=rtol, atol=atol)
-    return out_flash, out_no_flash
+
+    assert out_flash.shape == out_ref_fp32.shape == out_triton.shape
+    assert not torch.isnan(out_flash).any()
+    assert not torch.isnan(out_triton).any()
+    assert not torch.isnan(out_ref_fp32).any()
+    assert torch.isfinite(out_flash).all()
+    assert torch.isfinite(out_triton).all()
+    assert torch.isfinite(out_ref_fp32).all()
+
+    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()
+
+    triton_error = (out_triton - out_ref_fp32).abs().max().item()
+    flash_error = (out_flash - out_ref_fp32).abs().max().item()
+
+    assert flash_error <= rtol * triton_error + fwd_atol, (
+        f"Flash error {flash_error:.2e} exceeds {rtol}x Triton error {triton_error:.2e} + {fwd_atol:.2e}"
+    )
+
+    return out_flash, out_triton, out_ref_fp32


 def name_fn(score_mod):
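
One note on the reworked helper above: its accuracy check is relative rather than absolute. fwd_atol estimates the fp32 rounding noise floor of the reference output, and the flash result must then stay within rtol times the Triton backend's error plus that floor. A standalone sketch of the noise-floor trick:

import torch

x = torch.randn(2, 4, 512, 64, dtype=torch.float32)
# Algebraically zero; in float32 it leaves rounding error at the scale of the data,
# so twice its max absolute value acts as an absolute-error floor for comparisons.
fwd_atol = 2 * (x + 0.3 - 0.3 - x).abs().max().item()
print(f"{fwd_atol:.2e}")  # typically a few 1e-7 for unit-scale data
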
@@ -162,26 +198,6 @@ class TestFlexFlash(InductorTestCase):
         q, k, v = create_test_tensors(seq_len=seq_len, dtype=dtype, device=device)
         flash_vs_triton(q, k, v, score_mod=_causal)

-    @dtypes(torch.float16, torch.bfloat16)
-    def test_force_flash_error_with_block_mask(self, device, dtype):
-        """Test that force_flash=True raises error when BlockMask is provided."""
-        q, k, v = create_test_tensors(dtype=dtype, device=device)
-
-        # Create a causal block mask
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device)
-
-        compiled_fn = torch.compile(flex_attention)
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"force_flash=True but flash attention cannot be used.*BlockMask.*not supported",
-        ):
-            compiled_fn(
-                q, k, v, block_mask=block_mask, kernel_options={"force_flash": True}
-            )
-
     @dtypes(torch.float16, torch.bfloat16)
     def test_flash_attention_kernel_called(self, device, dtype):
         """Test that flash attention kernel is actually called when force_flash=True."""
@@ -257,7 +273,6 @@ class TestFlexFlash(InductorTestCase):
         """Test that force_flash=True raises error when tensor requires gradients."""
         q, k, v = create_test_tensors(dtype=dtype, device=device)

-        # Create a score mod with requires_grad tensor
         bias = torch.randn(4, device=device, dtype=dtype, requires_grad=True)

         def score_mod_with_grad(score, b, h, q_idx, kv_idx):
@@ -276,6 +291,108 @@ class TestFlexFlash(InductorTestCase):
                 kernel_options={"force_flash": True},
             )

+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_with_block_mask(self, device, dtype):
+        """Test flash attention with block mask and mask_mod."""
+        q, k, v = create_test_tensors(dtype=dtype, device=device)
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device)
+        flash_vs_triton(q, k, v, block_mask=block_mask)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_block_mask_with_score_mod(self, device, dtype):
+        """Test flash attention with both block mask and score_mod."""
+        q, k, v = create_test_tensors(dtype=dtype, device=device)
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device)
+        flash_vs_triton(q, k, v, score_mod=_times_two, block_mask=block_mask)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_with_mask_mod_buffer(self, device, dtype):
+        """Test flash attention with mask_mod that loads from buffer."""
+        q, k, v = create_test_tensors(
+            batch_size=2, num_heads=4, dtype=dtype, device=device
+        )
+
+        mask_bias = torch.randn(4, device=device, dtype=dtype) * 0.1
+
+        def custom_mask(b, h, q_idx, kv_idx):
+            bias_value = mask_bias[h]
+            return (q_idx >= kv_idx) | (bias_value > 0)
+
+        block_mask = create_block_mask(custom_mask, 2, 4, 512, 512, device=device)
+        flash_vs_triton(q, k, v, block_mask=block_mask)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_mask_mod_with_dual_buffers(self, device, dtype):
+        """Mask modifier should support multiple captured buffers."""
+        batch_size, num_heads, seq_len = 2, 4, 512
+        q, k, v = create_test_tensors(
+            batch_size=batch_size, num_heads=num_heads, dtype=dtype, device=device
+        )
+
+        head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.2
+        batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.2
+
+        def dual_buffer_mask(b, h, q_idx, kv_idx):
+            head_term = head_bias[h]
+            batch_term = batch_bias[b]
+            causal = q_idx >= kv_idx
+            bias_cond = (head_term + batch_term).to(torch.float32) > 0
+            return causal | bias_cond
+
+        block_mask = create_block_mask(
+            dual_buffer_mask, batch_size, num_heads, seq_len, seq_len, device=device
+        )
+        flash_vs_triton(q, k, v, block_mask=block_mask)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_score_mod_with_many_buffer_indexing(self, device, dtype):
+        batch_size, num_heads, seq_len = 2, 4, 512
+        q, k, v = create_test_tensors(
+            batch_size=batch_size, num_heads=num_heads, dtype=dtype, device=device
+        )
+
+        head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.15
+        query_scale = torch.randn(seq_len, device=device, dtype=dtype) * 0.05
+        kv_scale = torch.randn(seq_len, device=device, dtype=dtype) * 0.05
+        batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.1
+
+        def complex_score(score, b, h, q_idx, kv_idx):
+            head_term = head_bias[h]
+            query_term = query_scale[q_idx]
+            kv_term = kv_scale[kv_idx]
+            batch_term = batch_bias[b]
+            return score + head_term + query_term - kv_term + batch_term
+
+        flash_vs_triton(q, k, v, score_mod=complex_score)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_with_score_and_mask_buffers(self, device, dtype):
+        """Test flash attention with both score_mod and mask_mod using buffers."""
+        q, k, v = create_test_tensors(
+            batch_size=2, num_heads=4, dtype=dtype, device=device
+        )
+
+        score_bias = torch.randn(4, device=device, dtype=dtype) * 0.2
+        mask_bias = torch.randn(4, device=device, dtype=dtype) * 0.1
+
+        def score_with_buffer(score, b, h, q_idx, kv_idx):
+            return score + score_bias[h]
+
+        def mask_with_buffer(b, h, q_idx, kv_idx):
+            bias_value = mask_bias[h]
+            return (q_idx >= kv_idx) | (bias_value > 0)
+
+        block_mask = create_block_mask(mask_with_buffer, 2, 4, 512, 512, device=device)
+        flash_vs_triton(q, k, v, score_mod=score_with_buffer, block_mask=block_mask)
+

 instantiate_device_type_tests(TestFlexFlash, globals(), only_for="cuda")

@@ -65,6 +65,10 @@ class CuteDSLSubgraphInfo:
     body: IndentedBuffer
     template_mask: Optional[str] = None
     template_out: Optional[str] = None
+    cse: Optional[CSE[Any]] = None
+
+    def __post_init__(self):
+        self.only_copy_if_non_none_fields = ("cse",)

     def to_dict(self):
         return {
@@ -191,10 +195,15 @@ class CuteDSLTemplateKernel(Kernel):
                 body=IndentedBuffer(),
                 template_mask=None,
                 template_out=None,
+                cse=None,
             )

         subgraph = self.subgraph_bodies[body_name]
         for key, value in subgraph.to_dict().items():
+            if value is None and key in getattr(
+                subgraph, "only_copy_if_non_none_fields", ()
+            ):
+                continue
             setattr(self, key, value)

         try:
@@ -212,15 +221,17 @@ class CuteDSLTemplateKernel(Kernel):
             setattr(self, key, value)

     @contextlib.contextmanager
-    def create_subgraph_body(self, body_name: str):
+    def create_subgraph_body(self, body_name: str, *, clear_cse: bool = False):
         """Create a new subgraph body for template processing."""
         assert body_name not in self.subgraph_bodies, (
             f"Subgraph body '{body_name}' already exists"
         )
+        new_cse = self.cse.clone() if clear_cse else None
         self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo(
             body=IndentedBuffer(),
             template_mask=None,
             template_out=None,
+            cse=new_cse,
         )
         with self.set_subgraph_body(body_name):
             yield
@@ -294,7 +305,8 @@ class CuteDSLTemplateKernel(Kernel):

         # Register the hook and return placeholder
         placeholder = "<UNPACK_BUFFERS>"
-        assert placeholder not in self.render_hooks
+        # TODO: I think double invoking is fine for this specific hook
+        # assert placeholder not in self.render_hooks
         self.render_hooks[placeholder] = hook
         return placeholder

@@ -330,7 +342,7 @@ class CuteDSLTemplateKernel(Kernel):
         while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies:
             num += 1

-        with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"):
+        with self.create_subgraph_body(f"mod_{subgraph_number}_{num}", clear_cse=True):
             subgraph = self._get_subgraph(subgraph_number)
             modification_handler = ModificationWrapperCuteDSL(
                 self, subgraph_number, fixed_inputs, mask
@@ -429,40 +441,20 @@ class ModificationWrapperCuteDSL(V.WrapperHandler):  # type: ignore[name-defined]
         # val_frag[0] = tensor[index]
         # result = val_frag.load()

-        index_frag = self.kernel.cse.generate(
-            self.kernel.body,
-            "cute.make_fragment(1, cutlass.Int32)",
-            dtype=torch.int32,
-            bounds=ValueRanges.unknown(),
-        )
-
-        self.kernel.cse.generate(
-            self.kernel.body,
-            f"{index_frag}.store({index_str})",
-            dtype=torch.int32,
-            bounds=ValueRanges.unknown(),
-        )
-
-        val_frag = self.kernel.cse.generate(
-            self.kernel.body,
-            f"cute.make_fragment(1, {cute_dtype})",
-            dtype=var_dtype,
-            bounds=ValueRanges.unknown(),
-        )
-
-        index_var = self.kernel.cse.generate(
-            self.kernel.body,
-            f"{index_frag}[0]",
-            dtype=torch.int32,
-            bounds=ValueRanges.unknown(),
-        )
-
-        self.kernel.cse.generate(
-            self.kernel.body,
-            f"{val_frag}[0] = ({var}[{index_var}])",
-            dtype=var_dtype,
-            bounds=ValueRanges.unknown(),
-        )
+        index_frag = self.kernel.cse.newvar(dtype=torch.int32)
+        self.kernel.body.writeline(
+            f"{index_frag} = cute.make_fragment(1, cutlass.Int32)"
+        )
+        self.kernel.body.writeline(f"{index_frag}.store({index_str})")
+
+        val_frag = self.kernel.cse.newvar(dtype=var_dtype)
+        self.kernel.body.writeline(
+            f"{val_frag} = cute.make_fragment(1, {cute_dtype})"
+        )
+
+        index_var = self.kernel.cse.newvar(dtype=torch.int32)
+        self.kernel.body.writeline(f"{index_var} = {index_frag}[0]")
+        self.kernel.body.writeline(f"{val_frag}[0] = ({var}[{index_var}])")

         final_expr = f"{val_frag}.load()"

@@ -193,24 +193,6 @@ def flex_attention(
         score_mod_other_buffers,
         mask_mod_other_buffers,
     )
-    if _use_flex_flash_attention(
-        subgraph,
-        mask_graph,
-        kernel_options,
-        num_score_mod_placeholders=len(placeholder_inps),
-    ):
-        return create_flex_flash_attention_kernel(
-            query,
-            key,
-            value,
-            block_mask,
-            scale,
-            kernel_options,
-            subgraph_buffer,
-            mask_graph_buffer,
-            score_mod_other_buffers,
-            mask_mod_other_buffers,
-        )

     (
         query,
@@ -240,6 +222,30 @@ def flex_attention(
         ]
     )

+    if _use_flex_flash_attention(
+        subgraph,
+        mask_graph,
+        kernel_options,
+        num_score_mod_placeholders=len(placeholder_inps),
+    ):
+        return create_flex_flash_attention_kernel(
+            query,
+            key,
+            value,
+            block_mask,
+            scale,
+            kernel_options,
+            subgraph_buffer,
+            mask_graph_buffer,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
+            mask_graph=mask_graph,
+        )
+
     score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
     mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)

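
Illustrative sketch (not part of the diff) of where the block-sparse arguments now forwarded to create_flex_flash_attention_kernel come from: kv_num_blocks, kv_indices, full_kv_num_blocks, and full_kv_indices are metadata the BlockMask object already carries.

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

bm = create_block_mask(causal_mask, 2, 4, 512, 512, device="cuda")
# Counts and indices of partially-masked KV blocks per (batch, head, q-block)...
print(bm.kv_num_blocks.shape, bm.kv_indices.shape)
# ...and of fully-unmasked KV blocks; this pair can be None for some masks.
print(None if bm.full_kv_num_blocks is None else bm.full_kv_num_blocks.shape)
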
@@ -56,10 +56,8 @@ def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int):
     return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:])


-def is_trivial_graph(
-    graph_module: GraphModule, is_score_graph: bool, num_score_mod_placeholders: int
-):
-    """Check if the flex graphs are compatible with Flash Attention."""
+def is_trivial_mask_graph(graph_module: GraphModule) -> bool:
+    """Mask graph is trivial when it only gates via the default full op."""
     graph = graph_module.graph
     nodes = list(graph.nodes)
     placeholders = [n for n in nodes if n.op == "placeholder"]
@@ -67,14 +65,16 @@ def is_trivial_graph(
     assert len(output) == 1, "Got graph w/ multiple outputs"
     output_val = output[0].args[0]

-    if is_score_graph:
-        if input_buffers_require_grads(graph_module, num_score_mod_placeholders):
-            return False
-        return True  # party on garth
     # mask mod graph is empty if we have 4 inputs and full_default output
     return len(placeholders) == 4 and output_val.target is torch.ops.aten.full.default


+@functools.lru_cache(maxsize=1)
+def _supports_nontrivial_mask_graphs() -> bool:
+    """Currently only supported on Hopper (SM90) GPUs."""
+    return torch.cuda.get_device_capability()[0] == 9
+
+
 def _can_use_flex_flash_attention(
     subgraph: Subgraph, mask_graph: Subgraph, num_score_mod_placeholders: int
 ) -> tuple[bool, str]:
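
For orientation, an illustrative sketch (not from the commit) of what the "trivial" mask graph accepted above looks like: with no real mask_mod, the traced graph has only its four index placeholders and produces a scalar True via aten.full.default, which is the pattern is_trivial_mask_graph checks for. Anything else counts as a non-trivial mask graph and, per _supports_nontrivial_mask_graphs, is only taken on SM90.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def noop_mask(b, h, q_idx, kv_idx):
    # Stand-in for the default "allow everything" mask used when no BlockMask is given.
    return torch.full((), True, dtype=torch.bool)

idx = torch.tensor(0)
gm = make_fx(noop_mask)(idx, idx, idx, idx)
print(gm.graph)
# Four placeholders (b, h, q_idx, kv_idx) and an output produced by
# torch.ops.aten.full.default -> trivial by the check above.
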
@@ -91,32 +91,15 @@ def _can_use_flex_flash_attention(
             False,
             "Input buffers require gradients (not supported by flash attention)",
         )
+    mask_trivial = is_trivial_mask_graph(mask_graph.graph_module)

-    score_trivial = is_trivial_graph(
-        subgraph.graph_module,
-        is_score_graph=True,
-        num_score_mod_placeholders=num_score_mod_placeholders,
-    )
-    mask_trivial = is_trivial_graph(
-        mask_graph.graph_module,
-        is_score_graph=False,
-        num_score_mod_placeholders=num_score_mod_placeholders,
-    )
+    if mask_trivial:
+        return True, ""

-    if not score_trivial and not mask_trivial:
+    if not _supports_nontrivial_mask_graphs():
         return (
             False,
-            "Both score and mask graphs are too complex for flash attention (require simple operations only)",
-        )
-    elif not score_trivial:
-        return (
-            False,
-            "Score modification captured tensors that require gradients (not supported by flash attention)",
-        )
-    elif not mask_trivial:
-        return (
-            False,
-            "A non None BlockMask was passed to flex attention (not supported by flash attention yet)",
+            "NYI: Non-trivial mask graphs only supported on Hopper (SM90) for flash attention",
         )

     return True, ""
@@ -154,6 +137,11 @@ def create_flex_flash_attention_kernel(
     mask_graph_buffer: SubgraphResults,
     score_mod_other_buffers: list[TensorBox],
     mask_mod_other_buffers: list[TensorBox],
+    kv_num_blocks: TensorBox | None,
+    kv_indices: TensorBox | None,
+    full_kv_num_blocks: TensorBox | None,
+    full_kv_indices: TensorBox | None,
+    mask_graph: Subgraph,
 ) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox | ShapeAsConstantBuffer]:
     """Create a flex flash attention kernel using CuteDSL template."""
     if not ensure_flash_available():
@@ -193,17 +181,34 @@ def create_flex_flash_attention_kernel(
         stride=[sympy.sympify(s) for s in output.get_stride()],
     )

+    # Used to check if we can skip block sparse impl
+    mask_graph_is_trivial = is_trivial_mask_graph(mask_graph.graph_module)
+
+    needs_block_mask = not mask_graph_is_trivial
+    has_full_blocks = full_kv_num_blocks is not None
+
     choices: list[Any] = []
     causal = kernel_options.get("causal", False)
     assert flash_attention_cutedsl_template is not None

+    input_nodes = [query, key, value, lse]
+    if has_full_blocks:
+        input_nodes.extend(
+            [kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices]
+        )
+
+    if needs_block_mask and not has_full_blocks:
+        raise NotImplementedError(
+            "Flash attention with block mask but without full blocks is not supported yet"
+        )
+
     error = flash_attention_cutedsl_template.maybe_append_choice(
         choices,
-        input_nodes=[query, key, value, lse],
+        input_nodes=input_nodes,
         layout=output_layout,
         mutated_inputs=[lse],
         subgraphs=[subgraph_buffer, mask_graph_buffer],
         SM_SCALE=scale,
         CAUSAL=causal,
+        NEEDS_BLOCK_MASK=needs_block_mask,
     )

     if error or not choices:

@@ -1,6 +1,10 @@
+{% if NEEDS_BLOCK_MASK %}
+{{def_kernel("Q", "K", "V", "LOGSUMEXP", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
+{% else %}
 {{def_kernel("Q", "K", "V", "LOGSUMEXP")}}
+{% endif %}
     from flash_attn.cute.interface import _flash_attn_fwd
     from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch

     # Transpose tensors for _flash_attn_fwd compatibility (B,H,M,D) -> (B,M,H,D)
     q_transposed = Q.transpose(1, 2)
@@ -26,6 +30,25 @@
     output = {{get_output()}}
     output_transposed = output.transpose(1, 2)

+{% if NEEDS_BLOCK_MASK %}
+    @cute.jit
+    def mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors):
+        {{unpack_buffers("aux_tensors", indent_width=8)}}
+        {{ modification(
+            subgraph_number=1,
+            output_name="mask_mod_output",
+            b="b_idx",
+            h="h_idx",
+            m="q_idx",
+            n="kv_idx",
+        ) | indent_except_first(2) }}
+        return mask_mod_output
+    block_sparse_tensors = BlockSparseTensorsTorch(KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX)
+{% else %}
+    block_sparse_tensors = None
+    mask_mod = None
+{% endif %}
+
     # Collect any additional tensor buffers that were added during modifications
     {% set tensor_buffers = get_tensor_buffers() -%}
     {% if tensor_buffers -%}
@@ -41,10 +64,11 @@
         k_transposed,
         v_transposed,
         softmax_scale={{SM_SCALE}},
         causal={{CAUSAL}},
         return_lse=True,
         score_mod=score_mod,
+        mask_mod=mask_mod,
         out=output_transposed,
         lse=LOGSUMEXP,
+        block_sparse_tensors=block_sparse_tensors,
         aux_tensors=buffers
     )