[Graph Partition][Flex Attention] analyze symints from subgraph inputs and outputs (#152878)

Flex Attention may have symints in its subgraph inputs and outputs. The existing code implicitly captures these symints but does not explicitly store them on the TritonTemplateBuffer. This leads to errors when analyzing the symints used by a Flex Attention TritonTemplateBuffer. This PR fixes the issue by recording the subgraph inputs and outputs on the buffer and including their free symbols in get_free_symbol_uses.
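As a minimal illustration of the pattern (a hypothetical stand-in class, not Inductor's actual API): a template buffer reports only the free symbols it can see directly, so a symint that appears only in a captured subgraph input or output is missed unless the buffer stores those inputs/outputs and walks them when collecting free symbol uses.

import sympy

# Hypothetical, simplified stand-in for a template buffer; FakeTemplateBuffer
# and its attributes are illustrative only, not Inductor's real classes.
class FakeTemplateBuffer:
    def __init__(self, direct_symbols):
        self.direct_symbols = set(direct_symbols)  # symbols visible via regular inputs
        self.subgraph_inps = None  # filled in by the lowering, as in this PR
        self.subgraph_outs = None  # outputs would be walked the same way

    def get_free_symbol_uses(self):
        res = set(self.direct_symbols)
        # Without walking subgraph_inps (and subgraph_outs), a symint captured
        # only by the subgraph is invisible to graph-partition analysis.
        for inp in self.subgraph_inps or []:
            if isinstance(inp, sympy.Expr):
                res |= inp.free_symbols
        return res

s0, s1 = sympy.symbols("s0 s1")
buf = FakeTemplateBuffer(direct_symbols={s0})
buf.subgraph_inps = [s1 * 2]  # symint expression captured only by the subgraph
print(buf.get_free_symbol_uses())  # {s0, s1}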

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152878
Approved by: https://github.com/drisspg
Boyuan Feng authored on 2025-05-08 20:25:35 +00:00, committed by PyTorch MergeBot
parent 6ae7730eeb
commit 590965f92f
4 changed files with 118 additions and 17 deletions


@@ -5504,8 +5504,7 @@ class TestLearnableBiases(InductorTestCase):
out_eager, out_compiled, out_gold, (bias,), names=["out", "bias"]
)
@skip_on_cpu
def test_flex_attention_with_dynamic_max_autotune(self, device):
def _test_flex_attention_with_dynamic_max_autotune(self, device):
query = torch.randn(2, 16, 512, 64, device=device)
key = torch.randn(2, 16, 512, 64, device=device)
value = torch.randn(2, 16, 512, 64, device=device)
@@ -5545,6 +5544,15 @@ class TestLearnableBiases(InductorTestCase):
out.shape, query.shape, f"Expected shape {query.shape}, got {out.shape}"
)
@skip_on_cpu
def test_flex_attention_with_dynamic_max_autotune(self, device):
self._test_flex_attention_with_dynamic_max_autotune(device)
@skip_on_cpu
@torch._inductor.config.patch("graph_partition", True)
def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device):
self._test_flex_attention_with_dynamic_max_autotune(device)
@skip_on_cpu
def test_inspect_bug(self, device):
# https://github.com/pytorch/pytorch/issues/139374


@@ -4594,6 +4594,32 @@ class TritonTemplateBuffer(TemplateBuffer):
allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
)
self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
res = super().get_free_symbol_uses(unbacked_only)
subgraph_outs = self.subgraph_outs if self.subgraph_outs else []
subgraph_inps = self.subgraph_inps if self.subgraph_inps else []
for inp in subgraph_inps:
if isinstance(inp, sympy.Expr):
res.update(get_free_symbols(inp, unbacked_only))
elif isinstance(inp, IRNode):
res.update(inp.get_free_symbol_uses(unbacked_only))
else:
assert inp is None
for out in subgraph_outs:
if isinstance(out, IRNode):
res.update(out.get_free_symbol_uses(unbacked_only))
else:
assert out is None
return res
def get_outputs(self) -> list[Buffer]:
return self.outputs


@@ -255,6 +255,20 @@ def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> Subgraph
return build_subgraph_module_buffer(args, subgraph.graph_module)
def get_fwd_subgraph_outputs(
subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults
) -> list[Optional[ComputedBuffer]]:
subgraph_buffer = (
subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer]
)
mask_graph_buffer = (
mask_graph_buffer
if isinstance(mask_graph_buffer, Sequence)
else [mask_graph_buffer]
)
return [*subgraph_buffer, *mask_graph_buffer]
# Inner Triton functions shared by flex_attention & split-k decoding kernels.
compute_next_offset_func = r"""
@triton.jit
@@ -1222,6 +1236,15 @@ def lower_cpu(
inputs_for_autotuning,
layout,
)
# need subgraph inputs and outputs to analyze all symints used in flex attention
res.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
mask_mod_other_buffers
)
res.data.data.subgraph_outs = get_fwd_subgraph_outputs(
subgraph_buffer, mask_graph_buffer
)
return (res,)
@@ -1570,23 +1593,27 @@ def flex_attention(
6: create_num_blocks_fake_generator(full_kv_indices),
7: create_indices_fake,
}
return (
autotune_select_algorithm(
"flex_attention",
choices,
# Need to filter out symbols since there is an invariant
# that all input_nodes are of type IRNode
[
x
for x in inputs_for_autotuning
if isinstance(x, torch._inductor.ir.IRNode)
],
layout,
input_gen_fns=input_gen_fns,
),
logsumexp,
out = autotune_select_algorithm(
"flex_attention",
choices,
# Need to filter out symbols since there is an invariant
# that all input_nodes are of type IRNode
[x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)],
layout,
input_gen_fns=input_gen_fns,
)
# need subgraph inputs and outputs to analyze all symints used in flex attention
out.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
mask_mod_other_buffers
)
out.data.data.subgraph_outs = get_fwd_subgraph_outputs(
subgraph_buffer, mask_graph_buffer
)
return (out, logsumexp)
# ---------------------------- Backward HOP Implementation ----------------------------
@@ -2715,6 +2742,14 @@ def flex_attention_backward(*args, **kwargs):
input_gen_fns=input_gen_fns,
) # [Bq, Hkv, seq_len_kv, k_head_dim]
# need subgraph inputs and outputs to analyze all symints used in flex attention
broadcasted_grad_key.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
mask_mod_other_buffers
)
broadcasted_grad_key.data.data.subgraph_outs = get_bwd_subgraph_outputs(
fw_subgraph_buffer, mask_graph_buffer, joint_outputs
)
if V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv)):
grad_key = broadcasted_grad_key
grad_value = broadcasted_grad_value
@@ -2728,3 +2763,26 @@ def flex_attention_backward(*args, **kwargs):
grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True)
return (grad_query, grad_key, grad_value, tuple(joint_outputs.captured_grads))
def get_bwd_subgraph_outputs(
subgraph_buffer: SubgraphResults,
mask_graph_buffer: SubgraphResults,
joint_outputs: JointOutputResult,
) -> list[Optional[Union[ComputedBuffer, TensorBox]]]:
subgraph_buffer = (
subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer]
)
mask_graph_buffer = (
mask_graph_buffer
if isinstance(mask_graph_buffer, Sequence)
else [mask_graph_buffer]
)
joint_output_buffers = [
joint_outputs.grad_input,
*joint_outputs.captured_grads_compute,
*joint_outputs.captured_grads,
*joint_outputs.mutated_grads,
]
return [*subgraph_buffer, *mask_graph_buffer, *joint_output_buffers]


@@ -20,6 +20,7 @@ from .flex_attention import (
create_indices_fake,
create_num_blocks_fake_generator,
get_bounded_indices_func,
get_fwd_subgraph_outputs,
load_checked_2d,
load_checked_block,
maybe_realize,
@@ -604,6 +605,14 @@ def create_flex_decoding_kernel(*args, **kwargs):
input_gen_fns=input_gen_fns,
)
# need subgraph inputs and outputs to analyze all symints used in flex attention
buf_ACC.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
mask_mod_other_buffers
)
buf_ACC.data.data.subgraph_outs = get_fwd_subgraph_outputs(
score_mod_subgraph, mask_mod_subgraph
)
# Reduction
g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]