[FlexAttention] Add mechanism to get optimal autotune decision (#165817)
Script: https://github.com/meta-pytorch/attention-gym/pull/169

This feels directionally okay, but there is some bike-shedding: with a simple cache key, the logged keys could be fairly prone to collisions as the mask mod and score mod change.

Use case: https://github.com/meta-pytorch/attention-gym/pull/169

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165817
Approved by: https://github.com/Chillee
This commit is contained in:
parent
544b443ea1
commit
5016e7b2eb
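Before the diff, a rough usage sketch of the mechanism this commit adds: point the TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE environment variable at a file, compile a flex_attention call with max-autotune, then read back the JSON log of benchmarked choices. This mirrors the new test below; the /tmp path, shapes, and dtype are illustrative assumptions, not a prescribed API.

# Hedged sketch; mirrors the new test rather than documenting a stable API.
import json
import os

# Path is illustrative; Inductor appends a ".json" suffix when it writes the log.
os.environ["TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE"] = "/tmp/flex_attention_configs"

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

q, k, v = (
    torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


block_mask = create_block_mask(causal_mask, 1, 1, 128, 128, device="cuda")
compiled = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")
out = compiled(q, k, v, block_mask=block_mask)

# Each log entry maps a dims key such as
# "('forward', B, Hq, Hkv, seq_len_q, seq_len_kv, qk_head_dim, v_head_dim)"
# to the benchmarked choices, fastest first.
with open("/tmp/flex_attention_configs.json") as f:
    for entry in json.load(f):
        dims_key, choices = next(iter(entry.items()))
        print(dims_key, "->", choices[0])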
@@ -2,8 +2,11 @@
# flake8: noqa: B950

import functools
import json
import os
import random
import string
import tempfile
import unittest
import warnings
from collections import namedtuple
@@ -7045,6 +7048,120 @@ class TestLearnableBiases(InductorTestCase):
    def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device):
        self._test_flex_attention_with_dynamic_max_autotune(device)

    @skip_on_cpu
    def test_flex_attention_logging(self, device):
        with tempfile.TemporaryDirectory() as tmpdir:
            log_file = os.path.join(tmpdir, "flex_attention_configs")

            with patch.dict(
                os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file}
            ):
                query = torch.randn(
                    1,
                    2,
                    128,
                    64,
                    device=device,
                    dtype=torch.float16,
                    requires_grad=True,
                )
                key = torch.randn(
                    1,
                    2,
                    128,
                    64,
                    device=device,
                    dtype=torch.float16,
                    requires_grad=True,
                )
                value = torch.randn(
                    1,
                    2,
                    128,
                    64,
                    device=device,
                    dtype=torch.float16,
                    requires_grad=True,
                )

                def score_mod(score, b, h, q_idx, kv_idx):
                    return score * 2

                def causal_mask(b, h, q_idx, kv_idx):
                    return q_idx >= kv_idx

                block_mask = torch.compile(create_block_mask)(
                    causal_mask, 1, 1, 128, 128, device=device
                )

                compiled_flex = torch.compile(
                    flex_attention, mode="max-autotune-no-cudagraphs"
                )

                out = compiled_flex(
                    query=query,
                    key=key,
                    value=value,
                    score_mod=score_mod,
                    block_mask=block_mask,
                )

                out.sum().backward()

            json_file = log_file + ".json"
            self.assertTrue(
                os.path.exists(json_file), f"Log file {json_file} was not created"
            )

            with open(json_file) as f:
                log_data = json.load(f)

            self.assertIsInstance(log_data, list)
            self.assertEqual(len(log_data), 2)

            keys_seen = [next(iter(entry.keys())) for entry in log_data]

            expected_fwd_key = "('forward', 1, 2, 2, 128, 128, 64, 64)"
            expected_bwd_key = "('backward', 1, 2, 2, 128, 128, 64, 64)"

            self.assertIn(expected_fwd_key, keys_seen)
            self.assertIn(expected_bwd_key, keys_seen)

            for entry in log_data:
                self.assertIsInstance(entry, dict)
                self.assertEqual(len(entry), 1)

                dims_key = next(iter(entry.keys()))
                choices = entry[dims_key]

                kernel_type = eval(dims_key)[0]

                self.assertIsInstance(choices, list)
                self.assertGreater(len(choices), 0)

                for i, choice in enumerate(choices):
                    self.assertIn("type", choice)
                    self.assertIn("time", choice)

                    if choice["type"] == "triton":
                        self.assertIn("num_warps", choice)
                        self.assertIn("num_stages", choice)

                        if kernel_type == "forward":
                            self.assertIn("BLOCK_M", choice)
                            self.assertIn("BLOCK_N", choice)
                            self.assertNotIn("BLOCK_M1", choice)
                        elif kernel_type == "backward":
                            self.assertIn("BLOCK_M1", choice)
                            self.assertIn("BLOCK_N1", choice)
                            self.assertIn("BLOCK_M2", choice)
                            self.assertIn("BLOCK_N2", choice)
                            self.assertNotIn("BLOCK_M", choice)
                            self.assertNotIn("BLOCK_N", choice)

                    if i > 0:
                        self.assertLessEqual(choices[0]["time"], choice["time"])

    @skip_on_cpu
    def test_inspect_bug(self, device):
        # https://github.com/pytorch/pytorch/issues/139374
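Judging from the assertions in test_flex_attention_logging above, a single forward entry in the log would look roughly like the sketch below; the timings and block sizes are made up for illustration.

# Illustrative structure only; the times and tile sizes are invented.
example_forward_entry = {
    "('forward', 1, 2, 2, 128, 128, 64, 64)": [
        {
            "type": "triton",
            "time": 0.011,  # fastest benchmarked choice comes first
            "num_warps": 4,
            "num_stages": 3,
            "BLOCK_M": 64,
            "BLOCK_N": 64,
        },
        {
            "type": "triton",
            "time": 0.015,
            "num_warps": 8,
            "num_stages": 2,
            "BLOCK_M": 128,
            "BLOCK_N": 64,
        },
    ]
}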
@@ -17,6 +17,7 @@ import time
from collections.abc import Sequence
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
from typing_extensions import Self
@@ -2104,6 +2105,11 @@ class TritonTemplate(KernelTemplate):
                "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
                "waves_per_eu": kwargs.get("waves_per_eu", 0),
                "kpack": kwargs.get("kpack", 2),
                **{
                    k: kwargs[k]
                    for k in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS
                    if k in kwargs
                },
            },
            mutated_inputs=mutated_inputs,
            workspace_arg=workspace_arg,
@@ -2397,6 +2403,17 @@ def get_mm_log_filename() -> Optional[str]:
    return mm_file_name


@functools.cache
def get_flex_attention_log_filename() -> Optional[str]:
    flex_attention_file_name = os.environ.get(
        "TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE", None
    )
    if not flex_attention_file_name:
        return None

    return str(Path(flex_attention_file_name).with_suffix(".json"))


def append_to_log(filename, data):
    lock_file = filename.replace(".json", ".lock")
    lock = FileLock(lock_file)
@@ -2607,6 +2624,25 @@ class AlgorithmSelectorCache(PersistentCache):
    doesn't depend on the output layout.
    """

    FLEX_ATTENTION_TUNABLE_KEYS = tuple(
        dict.fromkeys(
            [
                "num_warps",
                "num_stages",
                "BLOCK_M",
                "BLOCK_N",
                "BLOCK_M1",
                "BLOCK_N1",
                "BLOCK_M2",
                "BLOCK_N2",
                "USE_TMA",
                "kpack",
                "matrix_instr_nonkdim",
                "waves_per_eu",
            ]
        )
    )

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

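A side note on the tuple(dict.fromkeys([...])) idiom in FLEX_ATTENTION_TUNABLE_KEYS above: it removes duplicates while preserving insertion order, which a plain set() would not guarantee, so the logged keys stay in a stable order. A minimal illustration:

# dict.fromkeys keeps first-seen order while dropping duplicates.
keys = ["num_warps", "BLOCK_M", "num_warps", "BLOCK_N"]
print(tuple(dict.fromkeys(keys)))  # ('num_warps', 'BLOCK_M', 'BLOCK_N')
print(tuple(set(keys)))            # same elements, but arbitrary order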
@@ -3540,6 +3576,73 @@ class AlgorithmSelectorCache(PersistentCache):
            )
        return pruned_choices

    @staticmethod
    def get_flex_attention_choice_info(
        choice: ChoiceCaller, timings: dict[ChoiceCaller, float]
    ) -> dict[str, Any]:
        if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
            return {"type": "extern", "time": timings[choice]}

        assert isinstance(choice, torch._inductor.select_algorithm.TritonTemplateCaller)

        info = choice.info_dict()
        result = {
            "type": "triton",
            "time": timings[choice],
        }

        for key in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS:
            if key in info:
                result[key] = info[key]

        return result

    @staticmethod
    def maybe_log_flex_attention_results(
        name: str, input_nodes: list[ir.IRNode], timings: dict[ChoiceCaller, float]
    ) -> None:
        flex_attention_filename = get_flex_attention_log_filename()
        if not flex_attention_filename or "flex_attention" not in name:
            return

        if len(input_nodes) < 3:
            return

        query_size = input_nodes[0].get_size()
        key_size = input_nodes[1].get_size()
        value_size = input_nodes[2].get_size()

        B = query_size[0]
        Hq = query_size[1]
        seq_len_q = query_size[2]
        qk_head_dim = query_size[3]
        Hkv = key_size[1]
        seq_len_kv = key_size[2]
        v_head_dim = value_size[3]

        kernel_type = "backward" if "backward" in name else "forward"
        dims_key = str(
            (
                kernel_type,
                B,
                Hq,
                Hkv,
                seq_len_q,
                seq_len_kv,
                qk_head_dim,
                v_head_dim,
            )
        )

        sorted_choices = sorted(timings, key=timings.__getitem__)
        out_dict = {
            dims_key: [
                AlgorithmSelectorCache.get_flex_attention_choice_info(choice, timings)
                for choice in sorted_choices
            ]
        }
        append_to_log(flex_attention_filename, out_dict)

    @staticmethod
    def log_results(
        name: str,
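To tie this back to the attention-gym use case in the commit message: since maybe_log_flex_attention_results appends choices sorted by benchmark time, the first choice under each dims key is the autotuner's winner. A hypothetical post-processing helper (the name and path are illustrative, not part of this commit) might look like:

# Hypothetical downstream helper; reads the log written by this commit's mechanism.
import ast
import json


def best_flex_configs(log_path: str) -> dict[tuple, dict]:
    """Return the fastest benchmarked choice for each dims key in the log.

    Hypothetical helper for downstream tooling; not part of this commit.
    """
    with open(log_path) as f:
        entries = json.load(f)

    best: dict[tuple, dict] = {}
    for entry in entries:
        dims_key, choices = next(iter(entry.items()))
        dims = ast.literal_eval(dims_key)  # dims_key is the str() of a tuple
        # Choices are logged fastest-first, so choices[0] is the winner;
        # keep the overall fastest if the same dims key appears more than once.
        if dims not in best or choices[0]["time"] < best[dims]["time"]:
            best[dims] = choices[0]
    return best


# Example usage (path is illustrative):
# for dims, choice in best_flex_configs("/tmp/flex_attention_configs.json").items():
#     print(dims, choice)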
@@ -3550,6 +3653,7 @@ class AlgorithmSelectorCache(PersistentCache):
        prescreening_elapse: Optional[float] = None,
        hint_override: Optional[int] = None,
    ):
        """Log the autotuning results, currently only handles mm and flex"""
        V.debug.log_autotuning_results(
            name, input_nodes, timings, elapse, precompile_elapse
        )
@@ -3618,6 +3722,10 @@ class AlgorithmSelectorCache(PersistentCache):

            append_to_log(mm_filename, out_dict)

        AlgorithmSelectorCache.maybe_log_flex_attention_results(
            name, input_nodes, timings
        )

        best_time = timings[best]
        sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
        sys.stderr.write(f"strides: {strides}\n")