[FlexAttention] Add mechanism to get optimal autotune decision (#165817)

Script: https://github.com/meta-pytorch/attention-gym/pull/169

Feels directionally okay, but there is some bikeshedding left: with the current simple cache key, keys could easily collide depending on how the mask mod and score mod change.

Usecase: https://github.com/meta-pytorch/attention-gym/pull/169
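
For reference, the mechanism is driven by the `TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE` environment variable added below (a `.json` suffix is appended to the given path). A minimal usage sketch, assuming a CUDA device; the shapes, score mod, and mask mod are arbitrary examples:

```python
import os
# Point the logger at a path; the log is written to <path>.json.
os.environ["TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE"] = "/tmp/flex_attention_configs"

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def score_mod(score, b, h, q_idx, kv_idx):
    return score * 2

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

q, k, v = (
    torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)
block_mask = create_block_mask(causal_mask, 1, 1, 128, 128, device="cuda")

# Autotuning only happens under a max-autotune mode; every benchmarked choice
# is appended to /tmp/flex_attention_configs.json, sorted fastest-first.
compiled = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")
out = compiled(q, k, v, score_mod=score_mod, block_mask=block_mask)
```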

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165817
Approved by: https://github.com/Chillee
drisspg 2025-10-28 02:58:04 +00:00 committed by PyTorch MergeBot
parent 544b443ea1
commit 5016e7b2eb
2 changed files with 225 additions and 0 deletions

View File

@@ -2,8 +2,11 @@
# flake8: noqa: B950
import functools
import json
import os
import random
import string
import tempfile
import unittest
import warnings
from collections import namedtuple
@@ -7045,6 +7048,120 @@ class TestLearnableBiases(InductorTestCase):
def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device):
self._test_flex_attention_with_dynamic_max_autotune(device)
@skip_on_cpu
def test_flex_attention_logging(self, device):
with tempfile.TemporaryDirectory() as tmpdir:
log_file = os.path.join(tmpdir, "flex_attention_configs")
with patch.dict(
os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file}
):
query = torch.randn(
1,
2,
128,
64,
device=device,
dtype=torch.float16,
requires_grad=True,
)
key = torch.randn(
1,
2,
128,
64,
device=device,
dtype=torch.float16,
requires_grad=True,
)
value = torch.randn(
1,
2,
128,
64,
device=device,
dtype=torch.float16,
requires_grad=True,
)
def score_mod(score, b, h, q_idx, kv_idx):
return score * 2
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = torch.compile(create_block_mask)(
causal_mask, 1, 1, 128, 128, device=device
)
compiled_flex = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
)
out = compiled_flex(
query=query,
key=key,
value=value,
score_mod=score_mod,
block_mask=block_mask,
)
out.sum().backward()
json_file = log_file + ".json"
self.assertTrue(
os.path.exists(json_file), f"Log file {json_file} was not created"
)
with open(json_file) as f:
log_data = json.load(f)
self.assertIsInstance(log_data, list)
self.assertEqual(len(log_data), 2)
keys_seen = [next(iter(entry.keys())) for entry in log_data]
expected_fwd_key = "('forward', 1, 2, 2, 128, 128, 64, 64)"
expected_bwd_key = "('backward', 1, 2, 2, 128, 128, 64, 64)"
self.assertIn(expected_fwd_key, keys_seen)
self.assertIn(expected_bwd_key, keys_seen)
for entry in log_data:
self.assertIsInstance(entry, dict)
self.assertEqual(len(entry), 1)
dims_key = next(iter(entry.keys()))
choices = entry[dims_key]
kernel_type = eval(dims_key)[0]
self.assertIsInstance(choices, list)
self.assertGreater(len(choices), 0)
for i, choice in enumerate(choices):
self.assertIn("type", choice)
self.assertIn("time", choice)
if choice["type"] == "triton":
self.assertIn("num_warps", choice)
self.assertIn("num_stages", choice)
if kernel_type == "forward":
self.assertIn("BLOCK_M", choice)
self.assertIn("BLOCK_N", choice)
self.assertNotIn("BLOCK_M1", choice)
elif kernel_type == "backward":
self.assertIn("BLOCK_M1", choice)
self.assertIn("BLOCK_N1", choice)
self.assertIn("BLOCK_M2", choice)
self.assertIn("BLOCK_N2", choice)
self.assertNotIn("BLOCK_M", choice)
self.assertNotIn("BLOCK_N", choice)
if i > 0:
self.assertLessEqual(choices[0]["time"], choice["time"])
@skip_on_cpu
def test_inspect_bug(self, device):
# https://github.com/pytorch/pytorch/issues/139374
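
Based on the format the test above asserts (a JSON list of single-key entries, keyed by a stringified `(kernel_type, B, Hq, Hkv, seq_len_q, seq_len_kv, qk_head_dim, v_head_dim)` tuple, with choices sorted fastest-first), the log can be consumed downstream roughly as follows; `best_configs` is a hypothetical helper, not part of this PR:

```python
import json
from ast import literal_eval

def best_configs(path):
    """Return the fastest benchmarked choice for each (kernel_type, *dims) key."""
    with open(path) as f:
        entries = json.load(f)
    best = {}
    for entry in entries:
        dims_key, choices = next(iter(entry.items()))
        # Choices are already sorted by benchmark time, fastest first.
        best[literal_eval(dims_key)] = choices[0]
    return best

# e.g.:
# best_configs("/tmp/flex_attention_configs.json")[("forward", 1, 2, 2, 128, 128, 64, 64)]
# -> {"type": "triton", "time": ..., "num_warps": ..., "BLOCK_M": ..., "BLOCK_N": ...}
```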

View File

@@ -17,6 +17,7 @@ import time
from collections.abc import Sequence
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
from typing_extensions import Self
@@ -2104,6 +2105,11 @@ class TritonTemplate(KernelTemplate):
"matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
"waves_per_eu": kwargs.get("waves_per_eu", 0),
"kpack": kwargs.get("kpack", 2),
**{
k: kwargs[k]
for k in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS
if k in kwargs
},
},
mutated_inputs=mutated_inputs,
workspace_arg=workspace_arg,
@@ -2397,6 +2403,17 @@ def get_mm_log_filename() -> Optional[str]:
return mm_file_name
@functools.cache
def get_flex_attention_log_filename() -> Optional[str]:
flex_attention_file_name = os.environ.get(
"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE", None
)
if not flex_attention_file_name:
return None
return str(Path(flex_attention_file_name).with_suffix(".json"))
def append_to_log(filename, data):
lock_file = filename.replace(".json", ".lock")
lock = FileLock(lock_file)
@@ -2607,6 +2624,25 @@ class AlgorithmSelectorCache(PersistentCache):
doesn't depend on the output layout.
"""
FLEX_ATTENTION_TUNABLE_KEYS = tuple(
dict.fromkeys(
[
"num_warps",
"num_stages",
"BLOCK_M",
"BLOCK_N",
"BLOCK_M1",
"BLOCK_N1",
"BLOCK_M2",
"BLOCK_N2",
"USE_TMA",
"kpack",
"matrix_instr_nonkdim",
"waves_per_eu",
]
)
)
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -3540,6 +3576,73 @@ class AlgorithmSelectorCache(PersistentCache):
)
return pruned_choices
@staticmethod
def get_flex_attention_choice_info(
choice: ChoiceCaller, timings: dict[ChoiceCaller, float]
) -> dict[str, Any]:
if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
return {"type": "extern", "time": timings[choice]}
assert isinstance(choice, torch._inductor.select_algorithm.TritonTemplateCaller)
info = choice.info_dict()
result = {
"type": "triton",
"time": timings[choice],
}
for key in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS:
if key in info:
result[key] = info[key]
return result
@staticmethod
def maybe_log_flex_attention_results(
name: str, input_nodes: list[ir.IRNode], timings: dict[ChoiceCaller, float]
) -> None:
flex_attention_filename = get_flex_attention_log_filename()
if not flex_attention_filename or "flex_attention" not in name:
return
if len(input_nodes) < 3:
return
query_size = input_nodes[0].get_size()
key_size = input_nodes[1].get_size()
value_size = input_nodes[2].get_size()
B = query_size[0]
Hq = query_size[1]
seq_len_q = query_size[2]
qk_head_dim = query_size[3]
Hkv = key_size[1]
seq_len_kv = key_size[2]
v_head_dim = value_size[3]
kernel_type = "backward" if "backward" in name else "forward"
dims_key = str(
(
kernel_type,
B,
Hq,
Hkv,
seq_len_q,
seq_len_kv,
qk_head_dim,
v_head_dim,
)
)
sorted_choices = sorted(timings, key=timings.__getitem__)
out_dict = {
dims_key: [
AlgorithmSelectorCache.get_flex_attention_choice_info(choice, timings)
for choice in sorted_choices
]
}
append_to_log(flex_attention_filename, out_dict)
@staticmethod
def log_results(
name: str,
@@ -3550,6 +3653,7 @@ class AlgorithmSelectorCache(PersistentCache):
prescreening_elapse: Optional[float] = None,
hint_override: Optional[int] = None,
):
"""Log the autotuning results, currently only handles mm and flex"""
V.debug.log_autotuning_results(
name, input_nodes, timings, elapse, precompile_elapse
)
@@ -3618,6 +3722,10 @@ class AlgorithmSelectorCache(PersistentCache):
append_to_log(mm_filename, out_dict)
AlgorithmSelectorCache.maybe_log_flex_attention_results(
name, input_nodes, timings
)
best_time = timings[best]
sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
sys.stderr.write(f"strides: {strides}\n")