[Graph Partition][Flex Attention] analyze symints from subgraph inputs and outputs (#152878)
Flex Attention may have symints in its subgraph inputs and outputs. The existing code implicitly captures these symints but does not explicitly store them on the TritonTemplateBuffer, which leads to errors when analyzing the symints used by a Flex Attention TritonTemplateBuffer. This PR fixes the issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152878
Approved by: https://github.com/drisspg
This commit is contained in:
parent 6ae7730eeb
commit 590965f92f
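For context, a minimal repro-style sketch (not taken from this PR; the shapes, names, and toy score_mod are illustrative) of the scenario the new tests exercise: compiling flex_attention with dynamic shapes and max-autotune, the path where symints from the score_mod/mask_mod subgraphs reach Inductor's TritonTemplateBuffer.

import torch
from torch.nn.attention.flex_attention import flex_attention

def rel_bias(score, b, h, q_idx, kv_idx):
    # Toy score_mod; a score_mod that captures extra tensors would add subgraph
    # inputs whose sizes can become symbolic under dynamic shapes.
    return score + (q_idx - kv_idx)

query = torch.randn(2, 16, 512, 64, device="cuda")
key = torch.randn(2, 16, 512, 64, device="cuda")
value = torch.randn(2, 16, 512, 64, device="cuda")

# dynamic=True introduces symbolic sizes; max-autotune lowers flex attention
# through a TritonTemplateBuffer, the IR node this PR extends.
compiled = torch.compile(flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs")
out = compiled(query, key, value, score_mod=rel_bias)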
@@ -5504,8 +5504,7 @@ class TestLearnableBiases(InductorTestCase):
             out_eager, out_compiled, out_gold, (bias,), names=["out", "bias"]
         )
 
-    @skip_on_cpu
-    def test_flex_attention_with_dynamic_max_autotune(self, device):
+    def _test_flex_attention_with_dynamic_max_autotune(self, device):
         query = torch.randn(2, 16, 512, 64, device=device)
         key = torch.randn(2, 16, 512, 64, device=device)
         value = torch.randn(2, 16, 512, 64, device=device)
@@ -5545,6 +5544,15 @@ class TestLearnableBiases(InductorTestCase):
             out.shape, query.shape, f"Expected shape {query.shape}, got {out.shape}"
         )
 
+    @skip_on_cpu
+    def test_flex_attention_with_dynamic_max_autotune(self, device):
+        self._test_flex_attention_with_dynamic_max_autotune(device)
+
+    @skip_on_cpu
+    @torch._inductor.config.patch("graph_partition", True)
+    def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device):
+        self._test_flex_attention_with_dynamic_max_autotune(device)
+
     @skip_on_cpu
     def test_inspect_bug(self, device):
         # https://github.com/pytorch/pytorch/issues/139374
@@ -4594,6 +4594,32 @@ class TritonTemplateBuffer(TemplateBuffer):
             allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
         )
 
+        self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
+        self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
+
+    def get_free_symbol_uses(
+        self, unbacked_only: bool = False
+    ) -> OrderedSet[sympy.Symbol]:
+        res = super().get_free_symbol_uses(unbacked_only)
+        subgraph_outs = self.subgraph_outs if self.subgraph_outs else []
+        subgraph_inps = self.subgraph_inps if self.subgraph_inps else []
+
+        for inp in subgraph_inps:
+            if isinstance(inp, sympy.Expr):
+                res.update(get_free_symbols(inp, unbacked_only))
+            elif isinstance(inp, IRNode):
+                res.update(inp.get_free_symbol_uses(unbacked_only))
+            else:
+                assert inp is None
+
+        for out in subgraph_outs:
+            if isinstance(out, IRNode):
+                res.update(out.get_free_symbol_uses(unbacked_only))
+            else:
+                assert out is None
+
+        return res
+
     def get_outputs(self) -> list[Buffer]:
         return self.outputs
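The override above walks subgraph_inps/subgraph_outs and unions their free symbols into the result. A standalone, simplified sketch of that collection pattern, where FakeIRNode and the local get_free_symbols helper are stand-ins rather than Inductor's real classes and helpers:

import sympy

def get_free_symbols(expr, unbacked_only=False):
    # Simplified stand-in for Inductor's helper; the real one can restrict the
    # result to unbacked symbols when unbacked_only=True.
    return set(expr.free_symbols)

class FakeIRNode:
    # Stand-in for an IRNode whose sizes/strides may contain symbols.
    def __init__(self, size_exprs):
        self.size_exprs = size_exprs

    def get_free_symbol_uses(self, unbacked_only=False):
        syms = set()
        for e in self.size_exprs:
            syms |= get_free_symbols(e, unbacked_only)
        return syms

s0, s1 = sympy.symbols("s0 s1")
subgraph_inps = [s0 + 1, FakeIRNode([2 * s1]), None]

res = set()
for inp in subgraph_inps:
    if isinstance(inp, sympy.Expr):
        res |= get_free_symbols(inp)
    elif isinstance(inp, FakeIRNode):
        res |= inp.get_free_symbol_uses()
    else:
        assert inp is None  # placeholder entries are allowed and contribute nothing

print(sorted(res, key=str))  # [s0, s1]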
@@ -255,6 +255,20 @@ def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> Subgraph
     return build_subgraph_module_buffer(args, subgraph.graph_module)
 
 
+def get_fwd_subgraph_outputs(
+    subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults
+) -> list[Optional[ComputedBuffer]]:
+    subgraph_buffer = (
+        subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer]
+    )
+    mask_graph_buffer = (
+        mask_graph_buffer
+        if isinstance(mask_graph_buffer, Sequence)
+        else [mask_graph_buffer]
+    )
+    return [*subgraph_buffer, *mask_graph_buffer]
+
+
 # Inner Triton functions shared by flex_attention & split-k decoding kernels.
 compute_next_offset_func = r"""
 @triton.jit
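get_fwd_subgraph_outputs only normalizes each SubgraphResults value, which may be a single buffer or a sequence of buffers, into one flat list. A tiny self-contained illustration of that normalize-then-concatenate pattern, with a hypothetical Buf class standing in for ComputedBuffer:

from collections.abc import Sequence

class Buf:
    # Minimal stand-in for a ComputedBuffer.
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return f"Buf({self.name})"

def flatten_results(subgraph_buffer, mask_graph_buffer):
    # Wrap a lone buffer in a list, pass sequences through, then concatenate,
    # mirroring get_fwd_subgraph_outputs above.
    subgraph_buffer = (
        subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer]
    )
    mask_graph_buffer = (
        mask_graph_buffer
        if isinstance(mask_graph_buffer, Sequence)
        else [mask_graph_buffer]
    )
    return [*subgraph_buffer, *mask_graph_buffer]

print(flatten_results(Buf("score"), [Buf("mask"), None]))
# [Buf(score), Buf(mask), None]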
@@ -1222,6 +1236,15 @@ def lower_cpu(
         inputs_for_autotuning,
         layout,
     )
+
+    # need subgraph inputs and outputs to analyze all symints used in flex attention
+    res.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
+        mask_mod_other_buffers
+    )
+    res.data.data.subgraph_outs = get_fwd_subgraph_outputs(
+        subgraph_buffer, mask_graph_buffer
+    )
+
     return (res,)
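The two .data hops in res.data.data (and in the analogous assignments in the hunks below) reflect how Inductor wraps a lowered result, roughly a TensorBox around a StorageBox around the underlying template buffer. A toy sketch of that access pattern, using stand-in classes rather than Inductor's real ones:

class TemplateBuf:
    # Stand-in for the TritonTemplateBuffer that now carries the extra metadata.
    subgraph_inps = None
    subgraph_outs = None

class StorageBox:
    def __init__(self, data):
        self.data = data

class TensorBox:
    def __init__(self, data):
        self.data = data

# The lowering hands back a box-of-box wrapper, so reaching the template buffer
# takes two .data hops, matching the res.data.data spelling in the hunk above.
res = TensorBox(StorageBox(TemplateBuf()))
res.data.data.subgraph_inps = ["score_mod_buffers...", "mask_mod_buffers..."]
res.data.data.subgraph_outs = ["fwd_subgraph_outputs..."]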
@@ -1570,23 +1593,27 @@ def flex_attention(
         6: create_num_blocks_fake_generator(full_kv_indices),
         7: create_indices_fake,
     }
-    return (
-        autotune_select_algorithm(
-            "flex_attention",
-            choices,
-            # Need to filter out symbols since there is an invariant
-            # that all input_nodes are of type IRNode
-            [
-                x
-                for x in inputs_for_autotuning
-                if isinstance(x, torch._inductor.ir.IRNode)
-            ],
-            layout,
-            input_gen_fns=input_gen_fns,
-        ),
-        logsumexp,
-    )
+
+    out = autotune_select_algorithm(
+        "flex_attention",
+        choices,
+        # Need to filter out symbols since there is an invariant
+        # that all input_nodes are of type IRNode
+        [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)],
+        layout,
+        input_gen_fns=input_gen_fns,
+    )
+
+    # need subgraph inputs and outputs to analyze all symints used in flex attention
+    out.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
+        mask_mod_other_buffers
+    )
+    out.data.data.subgraph_outs = get_fwd_subgraph_outputs(
+        subgraph_buffer, mask_graph_buffer
+    )
+
+    return (out, logsumexp)
 
 
 # ---------------------------- Backward HOP Implementation ----------------------------
@@ -2715,6 +2742,14 @@ def flex_attention_backward(*args, **kwargs):
         input_gen_fns=input_gen_fns,
     ) # [Bq, Hkv, seq_len_kv, k_head_dim]
 
+    # need subgraph inputs and outputs to analyze all symints used in flex attention
+    broadcasted_grad_key.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
+        mask_mod_other_buffers
+    )
+    broadcasted_grad_key.data.data.subgraph_outs = get_bwd_subgraph_outputs(
+        fw_subgraph_buffer, mask_graph_buffer, joint_outputs
+    )
+
     if V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv)):
         grad_key = broadcasted_grad_key
         grad_value = broadcasted_grad_value
@@ -2728,3 +2763,26 @@ def flex_attention_backward(*args, **kwargs):
         grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True)
 
     return (grad_query, grad_key, grad_value, tuple(joint_outputs.captured_grads))
+
+
+def get_bwd_subgraph_outputs(
+    subgraph_buffer: SubgraphResults,
+    mask_graph_buffer: SubgraphResults,
+    joint_outputs: JointOutputResult,
+) -> list[Optional[Union[ComputedBuffer, TensorBox]]]:
+    subgraph_buffer = (
+        subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer]
+    )
+    mask_graph_buffer = (
+        mask_graph_buffer
+        if isinstance(mask_graph_buffer, Sequence)
+        else [mask_graph_buffer]
+    )
+    joint_output_buffers = [
+        joint_outputs.grad_input,
+        *joint_outputs.captured_grads_compute,
+        *joint_outputs.captured_grads,
+        *joint_outputs.mutated_grads,
+    ]
+
+    return [*subgraph_buffer, *mask_graph_buffer, *joint_output_buffers]
@@ -20,6 +20,7 @@ from .flex_attention import (
     create_indices_fake,
     create_num_blocks_fake_generator,
     get_bounded_indices_func,
+    get_fwd_subgraph_outputs,
     load_checked_2d,
     load_checked_block,
     maybe_realize,
@@ -604,6 +605,14 @@ def create_flex_decoding_kernel(*args, **kwargs):
         input_gen_fns=input_gen_fns,
     )
 
+    # need subgraph inputs and outputs to analyze all symints used in flex attention
+    buf_ACC.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
+        mask_mod_other_buffers
+    )
+    buf_ACC.data.data.subgraph_outs = get_fwd_subgraph_outputs(
+        score_mod_subgraph, mask_mod_subgraph
+    )
+
     # Reduction
 
     g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]