[Code Motion] Restructure flex attention kernel into flex subdirectory (#159437)

Mostly code motion: updating relative paths and moving some imports that previously had to be lazy up to top-level scope, now that we are free from the curse.

This will make it easier to add new templates and gives the code some organization.
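For orientation, here is a rough sketch of the resulting layout, pieced together from the diff below; only common.py, flex_attention, and flex_decoding are confirmed by the imports shown, and the name of the new CPU lowering module is an assumption:

    torch/_inductor/kernel/
        __init__.py            # now imports `flex` eagerly alongside mm, mm_common, mm_plus_mm
        flex/
            __init__.py        # imports flex_attention and flex_decoding so their register_lowering calls run
            common.py          # shared helpers plus the common Triton template strings
            flex_cpu.py        # CPU (CppFlexAttentionTemplate) lowering, name assumed
            flex_attention.py  # moved from torch/_inductor/kernel/flex_attention.py
            flex_decoding.py   # moved from torch/_inductor/kernel/flex_decoding.py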

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159437
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng, https://github.com/eellison, https://github.com/Skylion007
drisspg 2025-07-30 08:40:12 -07:00 committed by PyTorch MergeBot
parent 4defea1e2c
commit e221a1c853
7 changed files with 1017 additions and 965 deletions

View File

@@ -269,14 +269,6 @@
"torch/_inductor/kernel/conv.py": {
"def convolution()": 231
},
"torch/_inductor/kernel/flex_attention.py": {
"def flex_attention()": 303,
"def flex_attention_backward()": 323,
"def lower_cpu()": 273
},
"torch/_inductor/kernel/flex_decoding.py": {
"def create_flex_decoding_kernel()": 288
},
"torch/_inductor/kernel/mm.py": {
"def tuned_addmm()": 169,
"def tuned_mm()": 127,
@@ -344,4 +336,4 @@
"torch/_inductor/wrapper_benchmark.py": {
"def parse_profile_event_list()": 119
}
}
}

View File

@@ -1 +1 @@
from . import mm, mm_common, mm_plus_mm
from . import flex, mm, mm_common, mm_plus_mm

View File

@@ -0,0 +1,3 @@
# mypy: allow-untyped-defs
# Imported here and re-imported in the parent package so that register_lowering gets triggered
from . import flex_attention, flex_decoding

View File

@@ -0,0 +1,589 @@
# mypy: allow-untyped-defs
"""Common utilities and functions for flex attention kernels"""
import math
from collections.abc import Sequence
from typing import Any, Optional, Union
import sympy
import torch
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map
from ...ir import (
ComputedBuffer,
ExternKernel,
FixedLayout,
FlexibleLayout,
get_fill_order,
InputBuffer,
IRNode,
MutationLayoutSHOULDREMOVE,
Scatter,
ShapeAsConstantBuffer,
StorageBox,
Subgraph,
TensorBox,
)
from ...lowering import (
_full,
check_and_broadcast_indices,
expand,
index_output_size_and_inner_fn,
to_dtype,
)
from ...select_algorithm import realize_inputs
SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
def zeros_and_scatter_lowering(shape: list[int], indices, values):
"""To support backward on captured buffers, we register a specific lowering for our custom op."""
# Always accumulate into fp32 then cast
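# Overall flow: build a zero-initialized fp32 buffer of `shape`, broadcast `values` over the
# indexed region, and emit an atomic-add Scatter wrapped in a ComputedBuffer that mutates that buffer.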
grad = _full(0, values.get_device(), torch.float32, shape)
assert isinstance(grad, TensorBox)
grad.realize()
x_size = grad.get_size()
values = to_dtype(values, grad.get_dtype())
indices_loaders = [i.make_loader() if i is not None else None for i in indices]
indices, tensor_indices = check_and_broadcast_indices(indices, grad.get_device())
# We can use the first one since they are all required to be the same size
tensor_size = list(indices[tensor_indices[0]].get_size())
indexed_size = [x_size[i] for i in range(len(indices))]
expected_vals_size, inner_fn = index_output_size_and_inner_fn(
x_size,
indices,
tensor_indices,
tensor_size,
indices_loaders,
indexed_size,
None,
check=True,
)
values = expand(values, expected_vals_size)
device = grad.get_device()
assert device is not None
scatter = Scatter(
device=device,
dtype=grad.get_dtype(),
inner_fn=values.make_loader(),
ranges=expected_vals_size, # iter_ranges,
output_indexer=inner_fn,
scatter_mode="atomic_add",
)
buffer = ComputedBuffer(
name=grad.data.data.name, # type: ignore[attr-defined]
layout=MutationLayoutSHOULDREMOVE(grad),
data=scatter,
)
return buffer
def get_fwd_subgraph_outputs(
subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults
) -> list[Optional[ComputedBuffer]]:
subgraph_buffer = (
subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer]
)
mask_graph_buffer = (
mask_graph_buffer
if isinstance(mask_graph_buffer, Sequence)
else [mask_graph_buffer]
)
return [*subgraph_buffer, *mask_graph_buffer]
def build_subgraph_module_buffer(
args: list[Union[TensorBox, ShapeAsConstantBuffer]],
graph_module: torch.fx.GraphModule,
) -> SubgraphResults:
"""Takes the required args and produces the subgraph buffer.
The subgraph buffer is a ComputedBuffer that will be inlined into the triton template.
Args:
args: The args that are passed into the subgraph. Contains both fixed and lifted inputs.
graph_module: The GraphModule from which to produce the output node
"""
# This one we have to keep lazy
from ...subgraph_lowering import PointwiseSubgraphLowering
pw_subgraph = PointwiseSubgraphLowering(
graph_module,
root_graph_lowering=V.graph,
allowed_mutations=OrderedSet([torch.ops.flex_lib.zeros_and_scatter.default]),
additional_lowerings={
torch.ops.flex_lib.zeros_and_scatter.default: zeros_and_scatter_lowering
},
)
with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type]
pw_subgraph.run(*args)
# Since we are allowing mutations/buffer creation, we need to register any fresh buffers
# created during the pointwise subgraph lowering
if len(pw_subgraph.buffers) > 0:
for buffer in pw_subgraph.buffers:
V.graph.register_buffer(buffer)
def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]:
if output_buffer is None:
return None
if isinstance(output_buffer, ComputedBuffer):
# These nodes are coming from the output of zeros_and_scatter
return output_buffer
assert isinstance(output_buffer, TensorBox), (
"The output node for flex attention's subgraph must be a TensorBox, but got: ",
type(output_buffer),
)
assert isinstance(output_buffer.data, StorageBox), (
"The output node for the flex attention subgraph must be a StorageBox, but got: ",
type(output_buffer),
)
device = output_buffer.data.get_device()
assert device is not None
subgraph_buffer = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=device,
dtype=output_buffer.data.get_dtype(),
size=output_buffer.data.get_size(),
),
data=output_buffer.data.data, # type: ignore[arg-type]
)
return subgraph_buffer
return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs)
def build_subgraph_buffer(
args: list[Union[TensorBox, ShapeAsConstantBuffer]], subgraph: Subgraph
) -> SubgraphResults:
return build_subgraph_module_buffer(args, subgraph.graph_module)
def maybe_realize(args: list[Optional[IRNode]]):
"""Accepts a list of optional IRNodes and returns a list of realized IRNodes"""
return tree_map(
lambda x: (
realize_inputs(x)
if x is not None and not isinstance(x, sympy.Symbol)
else x
),
args,
)
def create_placeholder(
name: str,
dtype: torch.dtype,
device: torch.device,
size: Optional[list[int]] = None,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
"""Creates a placeholder input buffer for producing the subgraph output."""
input_buffer = InputBuffer(
name=name,
layout=FixedLayout(
device,
dtype,
size if size else [],
FlexibleLayout.contiguous_strides(size) if size else [],
),
)
return TensorBox.create(input_buffer)
def construct_strides(
sizes: Sequence[int],
fill_order: Sequence[int],
) -> Sequence[int]:
"""From a list of sizes and a fill order, construct the strides of the permuted tensor."""
# Initialize strides
assert len(sizes) == len(fill_order), (
"Length of sizes must match the length of the fill order"
)
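# Example: sizes=[2, 3, 4] with fill_order=[2, 1, 0] (dim 2 varies fastest) -> strides=[12, 4, 1]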
strides = [0] * len(sizes)
# Start with stride 1 for the innermost dimension
current_stride = 1
# Iterate through the fill order populating strides
for dim in fill_order:
strides[dim] = current_stride
current_stride *= sizes[dim]
return strides
def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]):
"""This is a mirror of the same function in aten/src/ATen/ExpandUtils.cpp
Args:
size: The size of the output tensor
orig_strides: The strides of the input tensor
Returns:
List[int]: Dense non-overlapping strides that preserve the input tensor's layout permutation.
The returned strides follow the same stride propagation rules as TensorIterator. This matches
the behavior of empty_like().
"""
fill_order = get_fill_order(orig_strides, V.graph.sizevars.shape_env)
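# e.g. size=[2, 3, 4] with orig_strides=[24, 1, 6] keeps dim 1 innermost -> dense strides [12, 1, 3]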
return construct_strides(size, fill_order)
def create_indices_fake(x) -> torch.Tensor:
"""Create fake indices used for autotuning."""
size = [V.graph.sizevars.size_hint(i) for i in x.get_size()]
indices = torch.arange(0, size[-1], dtype=x.get_dtype(), device=x.get_device())
indices = indices.expand(size).contiguous()
return indices
def create_num_blocks_fake_generator(sparse_indices):
"""Create a fake num_blocks that is used for autotuning.
The idea here is that we need to create a real tensor with real data
that's representative for benchmarking.
For example, returning all zeros for the `kv_num_blocks` input would mean
that we are computing 0 blocks for each row, which would provide bogus
autotuning results.
In this case, we choose to use min(16, max_block) blocks, because I
(Horace) think it'll probably result in pretty representative performance.
If it's too short then prefetching won't help. If it's too long then
autotuning will take longer for no good reason.
"""
def create_num_blocks_fake(x) -> torch.Tensor:
num_blocks_for_autotuning = V.graph.sizevars.size_hint(sparse_indices.shape[-1])
size = [V.graph.sizevars.size_hint(i) for i in x.get_size()]
return torch.full(
size,
num_blocks_for_autotuning,
dtype=x.get_dtype(),
device=x.get_device(),
)
return create_num_blocks_fake
def contiguous_last_dim(x):
"""Ensure that realized IR node has a contiguous stride in the last dimension."""
strides = x.maybe_get_stride()
if strides and strides[-1] != 1:
contiguous_stride_order = list(reversed(range(len(x.get_size()))))
return ExternKernel.require_stride_order(x, contiguous_stride_order)
return x
def set_head_dim_values(
kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars
):
"""
Mutates kernel options, adding head dimension calculations.
Args:
kernel_options: Dictionary to populate with options
qk_head_dim: Query/Key head dimension
v_head_dim: Value head dimension
graph_sizevars: Graph size variables object with guard_int method
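Example: qk_head_dim=96, v_head_dim=128 -> QK_HEAD_DIM=96, QK_HEAD_DIM_ROUNDED=128,
V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, SAFE_HEAD_DIM=False (96 is not a power of two)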
"""
# QK dimensions
qk_head_dim_static = graph_sizevars.guard_int(qk_head_dim)
kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static)
kernel_options.setdefault(
"QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static)
)
# V dimensions
v_head_dim_static = graph_sizevars.guard_int(v_head_dim)
kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static)
kernel_options.setdefault(
"V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static)
)
# Safety flag
kernel_options.setdefault(
"SAFE_HEAD_DIM",
is_power_of_2(qk_head_dim_static) and is_power_of_2(v_head_dim_static),
)
def is_power_of_2(n):
return n != 0 and ((n & (n - 1)) == 0)
def next_power_of_two(n):
if n <= 0:
return 1
return 2 ** math.ceil(math.log2(n))
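# Edge cases: next_power_of_two(0) == 1 and next_power_of_two(1) == 1; exact powers of two are returned unchanged.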
# ---- Common Template Strings ----
compute_forward_block_mn = r"""
@triton.jit
def forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
{{gen_defines() | indent_except_first(1)}}
# -- load k --
# NB: reversed order since K is transposed
{%- if USE_TMA %}
k = tl.load_tensor_descriptor(
desc_k,
[kv_start + kv_offset, 0],
)
{%- else %}
k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
{%- endif %}
if USE_TMA:
k = tl.trans(k)
# -- compute qk ---
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
# If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
# which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
out="qk"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non-divisible seqlen.
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
{{ modification(
subgraph_number=1,
output_name="mask_mod_output",
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
# apply mask for partially unmasked blocks
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# -- compute scaling constant ---
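# Online softmax rescaling: m_ij is the updated running row max, and alpha = exp2(m_i - m_ij)
# rescales the previously accumulated acc and l_i (fully masked rows are handled just below).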
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
if not ROWS_GUARANTEED_SAFE:
masked_out_rows = (m_ij == float("-inf"))
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
else:
m_ij_masked = m_ij
alpha = tl.math.exp2(m_i - m_ij_masked)
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
# NB: l_i update is pulled up here since it's a bit faster
# NB: For headdim=256, it's faster to move it back down to after `m_i = m_ij`
l_i = l_i * alpha + tl.sum(p, 1)
# -- scale and update acc --
acc = acc * alpha[:, None]
{%- if USE_TMA %}
v = tl.load_tensor_descriptor(
desc_v,
[kv_start + kv_offset, 0],
)
{%- else %}
v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
{%- endif %}
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
# -- update m_i
m_i = m_ij
return acc, l_i, m_i
"""
compute_forward_inner = r"""
@triton.jit
def forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr,
desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets used as inputs to score_mod & mask_mod
# of size [BLOCK_M, BLOCK_N] or scalar.
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
# blocksparse data
kv_indices, kv_num_blocks,
# start kv and end kv block
block_n_start, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
{{gen_defines() | indent_except_first(1)}}
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
RCP_LN2: tl.constexpr = 1.44269504
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
kv_offset = 0
# loop over k, v and update accumulator until block_n_end
for start_n in range(block_n_start, block_n_end):
# Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
if IS_DIVISIBLE:
acc, l_i, m_i = forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
else:
# Benchmarks show that even if we apply mod & mask to every block for non-divisible seqlens,
# it's on par with or slightly faster than only applying them to the last block in fwd.
# However, we choose a different strategy for bwd, where we only apply mod & mask
# to the last block because it's a lot faster.
acc, l_i, m_i = forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
offset = get_offset_for_next_block(
start_n, kv_indices, kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
)
offs_n = offs_n + offset
kv_offset += offset
if not USE_TMA:
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
return acc, l_i, m_i
"""
# Inner Triton functions shared by flex_attention & split-k decoding kernels.
compute_next_offset_func = r"""
@triton.jit
def get_offset_for_next_block(
loop_iter, col_indices, total_blocks,
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
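# Within a sparse block, advance by BLOCK each iteration; when the loop crosses a sparse-block
# boundary, jump from the end of the current sparse block to the start of the next one in col_indices.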
if BLOCKS_ARE_CONTIGUOUS:
return BLOCK
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
return offset
"""
get_bounded_indices_func = r"""
@triton.jit
def get_bounded_indices(indices, max_len=None):
return indices % max_len if max_len is not None else indices
"""
load_checked_block = r"""
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
if IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr)
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
else:
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
"""
load_checked_2d = r"""
@triton.jit
def load_checked_2d(
ptr,
offs_m,
offs_n,
stride_m,
stride_n,
IS_DIVISIBLE_M: tl.constexpr,
IS_DIVISIBLE_N: tl.constexpr,
M_LEN: tl.constexpr,
N_DIM: tl.constexpr,
):
# Calculate final pointer if strides are provided
if stride_m is not None and stride_n is not None:
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
# Handle all masking cases
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0)
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0)
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
else: # Both divisible
return tl.load(ptr)
"""

View File

@@ -0,0 +1,339 @@
# mypy: allow-untyped-defs
"""CPU-specific implementations for flex attention"""
import copy
import os
import sys
from typing import Any
import sympy
import torch
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import ValueRanges
from ...codegen.cpp_flex_attention_template import CppFlexAttentionTemplate
from ...ir import Buffer, FixedLayout, TensorBox
from ...select_algorithm import autotune_select_algorithm
from .common import (
build_subgraph_buffer,
build_subgraph_module_buffer,
contiguous_last_dim,
create_placeholder,
get_fwd_subgraph_outputs,
infer_dense_strides,
maybe_realize,
)
def check_cpu_supported():
requires_avx2_on_cpu = (
torch.cpu._is_avx2_supported() and os.getenv("ATEN_CPU_CAPABILITY") != "default"
)
supported = (
requires_avx2_on_cpu
and not torch.xpu.is_available()
and not sys.platform == "darwin"
)
return supported
def lower_cpu(
query,
key,
value,
subgraph,
block_mask,
scale,
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
):
"""CPP-based template for flex attention on x86 CPUs"""
(
_, # q_length
_, # kv_length
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
full_kv_indices,
q_num_blocks,
q_indices,
full_q_num_blocks,
full_q_indices,
SPARSE_Q_BLOCK_SIZE,
SPARSE_KV_BLOCK_SIZE,
mask_graph,
) = block_mask
if kernel_options["OUTPUT_LOGSUMEXP"]:
raise NotImplementedError(
"torch.compile on CPU only supports inference and `return_lse` is not supported yet."
)
if not check_cpu_supported():
raise NotImplementedError(
"torch.compile on current platform is not supported for CPU."
)
fake_buffers: list[Buffer] = [] # noqa: F821
# [Note] Handle the case where the split sizes are not statically known.
# The values of cur_qSplitSize and cur_kvSplitSize are decided at runtime.
# We use symbols to represent them during compilation here.
# They'll be replaced by the strings "cur_qSplitSize" and "cur_kvSplitSize" in
# the modification function of the CppFlexAttentionTemplate class.
cur_qSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr
cur_kvSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr
shape_env = V.graph.sizevars.shape_env
# We don't know the concrete value of cur_qSplitSize and cur_kvSplitSize during the compilation.
# Mark symbols > 1 to ensure broadcasting is always applied.
# This avoids treating them as equal when `eq(var, 1)` is evaluated in `broadcast_symbolic_shapes`.
shape_env.var_to_range[cur_qSplitSize] = ValueRanges(2, int_oo)
shape_env.var_to_range[cur_kvSplitSize] = ValueRanges(2, int_oo)
score_dtype = torch.float
placeholder_inps = [
create_placeholder(name, dtype, query.get_device(), size)
for name, dtype, size in [
("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]),
("b", torch.int64, []),
("h", torch.int64, []),
("q_idx", torch.int64, [cur_qSplitSize, 1]),
("kv_idx", torch.int64, [1, cur_kvSplitSize]),
]
]
subgraph_buffer = build_subgraph_buffer(
placeholder_inps + list(score_mod_other_buffers), subgraph
)
if subgraph_buffer is not None:
if isinstance(subgraph_buffer, list):
for _buf in subgraph_buffer:
if _buf is not None:
_buf.freeze_layout()
else:
subgraph_buffer.freeze_layout()
mask_graph_placeholder_inps = [
create_placeholder(name, dtype, query.get_device(), size)
for name, dtype, size in [
("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]),
("b", torch.int64, []),
("h", torch.int64, []),
("q_idx", torch.int64, [cur_qSplitSize, 1]),
("kv_idx", torch.int64, [1, cur_kvSplitSize]),
]
]
# The original mask_graph works on a scalar and only includes
# the logic of calculating the mask value.
# We need to add the logic of applying the mask to the qk_data tensor
# into the graph for the later codegen of this part.
# Example:
# mask_graph:
# def mask_fn(b, h, q_idx, kv_idx):
# mask = q_idx >= kv_idx
# return mask
# The converted_mask_graph should be:
# def converted_mask_fn(qk_data, b, h, q_idx, kv_idx):
# mask = q_idx >= kv_idx
# qk_data = torch.where(mask, qk_data, torch.full_like(qk_data, -float("inf")))
# return qk_data
def convert_mask_graph_module(mask_graph):
gm = copy.deepcopy(mask_graph.graph_module)
graph = gm.graph
# Add qk_data as the first input
with graph.inserting_before(next(iter(graph.nodes))):
qk_data_node = graph.placeholder("qk_data")
# Find the node that returns the mask
output_node = None
for node in graph.nodes:
if node.op == "output":
output_node = node
break
# Get the mask node
assert output_node is not None
mask_node = output_node.args[0]
size_node = [cur_qSplitSize, cur_kvSplitSize]
# Create a new node for torch.full
with graph.inserting_after(mask_node):
full_node = graph.call_function(
torch.full,
args=(size_node, -float("inf")),
kwargs={"dtype": score_dtype},
)
# Create a new node for torch.where
with graph.inserting_after(full_node):
where_node = graph.call_function(
torch.ops.aten.where, args=(mask_node, qk_data_node, full_node)
)
# Update the output node to return the result of torch.where
output_node.args = (where_node,)
graph.lint()
converted = torch.fx.GraphModule(gm, graph)
return converted
converted_mask_graph_module = convert_mask_graph_module(mask_graph)
mask_graph_buffer = build_subgraph_module_buffer(
mask_graph_placeholder_inps + list(mask_mod_other_buffers),
converted_mask_graph_module,
)
# Clear the pending fresh unbacked symbols that are created for cur_qSplitSize and cur_kvSplitSize in the current kernel.
pending = V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols
V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols = [
x for x in pending if x not in (cur_qSplitSize, cur_kvSplitSize)
]
buffer_list = (
placeholder_inps
+ list(score_mod_other_buffers)
+ mask_graph_placeholder_inps
+ list(mask_mod_other_buffers)
)
for item in buffer_list:
if isinstance(item, TensorBox):
fake_buffers.append(item.data.data) # type: ignore[attr-defined]
# CPU kernel requires last dim to be contiguous
query, key, value = map(contiguous_last_dim, [query, key, value])
(
query,
key,
value,
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
full_kv_indices,
q_num_blocks,
q_indices,
full_q_num_blocks,
full_q_indices,
) = maybe_realize(
[
query,
key,
value,
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
full_kv_indices,
q_num_blocks,
q_indices,
full_q_num_blocks,
full_q_indices,
]
)
if len(OrderedSet([query.get_name(), key.get_name(), value.get_name()])) != 3:
raise NotImplementedError(
"Unsupported for now if query, key, value are the same buffer."
)
if query.get_dtype() not in [torch.float, torch.bfloat16, torch.float16]:
raise NotImplementedError(
"Only `torch.float`, `torch.float16` and `torch.bfloat16` are supported in FlexAttention for the CPU device. "
f"Found input tensors with dtype `{query.get_dtype()}`."
)
score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
B = Bq
# Construct output layout with strides matching the query.
out_size = [B, Hq, seq_len_q, v_head_dim]
out_strides = infer_dense_strides(out_size, query.get_stride())
layout = FixedLayout(
query.get_device(),
query.get_dtype(),
[B, Hq, seq_len_q, v_head_dim],
stride=[sympy.sympify(s) for s in out_strides],
)
_choices: list[Any] = []
input_nodes = [query, key, value, kv_num_blocks, kv_indices]
if not full_kv_num_blocks:
no_full_kv_block = True
else:
no_full_kv_block = False
input_nodes += [full_kv_num_blocks]
input_nodes += [full_kv_indices]
has_other_buffer = False
kernel_input_name_to_buffer = {}
if score_mod_other_buffers or mask_mod_other_buffers:
has_other_buffer = True
for prefix, buffers in [
("score_others", score_mod_other_buffers),
("mask_others", mask_mod_other_buffers),
]:
kernel_input_name_to_buffer.update(
{f"{prefix}_{i}": buf for i, buf in enumerate(buffers)}
)
input_nodes += [
value
for value in kernel_input_name_to_buffer.values()
if not isinstance(value, sympy.Symbol)
]
skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False)
# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
), (
"Q seqlen must not exceed the block_mask size in the Q dimension; consider passing a larger block_mask."
)
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE))
), (
"KV seqlen must not exceed the block_mask size in the KV dimension; consider passing a larger block_mask."
)
CppFlexAttentionTemplate.add_choices(
choices=_choices,
input_nodes=input_nodes,
layout=layout,
scale=scale,
score_mod=None if skip_mask_score else subgraph_buffer,
mask_mod=None if skip_mask_score else mask_graph_buffer,
kv_block_size=SPARSE_KV_BLOCK_SIZE,
q_block_size=SPARSE_Q_BLOCK_SIZE,
has_other_buffer=has_other_buffer,
no_full_kv_block=no_full_kv_block,
fake_buffers=fake_buffers,
len_score_other=len(score_mod_other_buffers),
len_mask_other=len(mask_mod_other_buffers),
kernel_input_name_to_buffer=kernel_input_name_to_buffer,
block_vars=(cur_qSplitSize, cur_kvSplitSize),
)
inputs_for_autotuning = [
query,
key,
value,
]
res = autotune_select_algorithm(
"flex_attention",
_choices,
inputs_for_autotuning,
layout,
)
# need subgraph inputs and outputs to analyze all symints used in flex attention
res.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
mask_mod_other_buffers
)
res.data.data.subgraph_outs = get_fwd_subgraph_outputs(
subgraph_buffer, mask_graph_buffer
)
return (res,)

View File

@@ -8,12 +8,16 @@ import sympy
import torch
from torch._inductor.virtualized import V
from .. import ir
from ..ir import FixedLayout, FlexibleLayout
from ..lowering import empty, empty_strided, lowerings
from ..runtime.runtime_utils import is_power_of_2, next_power_of_2
from ..select_algorithm import autotune_select_algorithm, SymbolicGridFn, TritonTemplate
from .flex_attention import (
from ... import ir
from ...ir import FixedLayout, FlexibleLayout
from ...lowering import empty, empty_strided, lowerings
from ...runtime.runtime_utils import is_power_of_2, next_power_of_2
from ...select_algorithm import (
autotune_select_algorithm,
SymbolicGridFn,
TritonTemplate,
)
from .common import (
compute_forward_block_mn,
compute_forward_inner,
compute_next_offset_func,
@@ -24,6 +28,7 @@ from .flex_attention import (
load_checked_2d,
load_checked_block,
maybe_realize,
set_head_dim_values,
)
@@ -31,6 +36,45 @@ aten = torch.ops.aten
prims = torch.ops.prims
def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa) -> bool:
"""Decide which kernel to use; return True if the flex decoding kernel should be used.
Note:
Since the number of splits is calculated based on the number of batch and head dims,
we need to ensure that the batch and head dims are statically known. Otherwise we just
use the main flex_attention kernel.
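Example: a single-token decode query (q_len == 1) with statically known batch and head dims
typically takes the decoding path, while a query with q_len >= 128 always falls back to flex_attention.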
"""
force_flex = kernel_options.get("FORCE_USE_FLEX_ATTENTION", False)
short_query_length = V.graph.sizevars.evaluate_expr(
sympy.Lt(query.get_size()[-2], 128)
)
non_zero_length = V.graph.sizevars.evaluate_expr(sympy.Gt(query.get_size()[-2], 0))
static_batch = isinstance(query.get_size()[0], (int, sympy.Integer))
static_num_heads = isinstance(query.get_size()[1], (int, sympy.Integer))
if enable_gqa:
# in the current flex decoding triton kernel, grouped query heads for the
# same kv head are handled by the same block. So it's hard to support different
# kv num blocks for grouped query heads. We just fall back to the main flex_attention
# kernel where each query head is handled by a separate block.
valid_block_mask_num_heads = V.graph.sizevars.evaluate_expr(
sympy.Eq(kv_indices.get_size()[1], 1)
)
else:
valid_block_mask_num_heads = V.graph.sizevars.evaluate_expr(
sympy.Or(
sympy.Eq(kv_indices.get_size()[1], 1),
sympy.Eq(kv_indices.get_size()[1], query.get_size()[1]),
)
)
return (
not force_flex
and short_query_length
and static_batch
and static_num_heads
and non_zero_length
and valid_block_mask_num_heads
)
@SymbolicGridFn
def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta):
"""How is this kernel parallelized?
@@ -322,8 +366,7 @@ def get_split_k(B: int, H: int, Mk: int) -> int:
def create_flex_decoding_kernel(*args, **kwargs):
from .flex_attention import set_head_dim_values
"""Flex decode lowering that is optimized for small Q_LEN and GQA packing"""
(
query,
key,