Add support for capturing tensors with score_mod (#124444)
```
import torch
import torch.nn.functional as F
from torch.nn.attention._templated_attention import _templated_attention as templated_attention
from triton.testing import do_bench

torch.manual_seed(0)

B = 16
H = 16
S = 2048
D = 64

head_scale = torch.randn(H, device="cuda")


def alibi(score, batch, head, token_q, token_kv):
    return score + torch.ops.aten.index(head_scale, [head]) * (token_q - token_kv)


bias = torch.randn(H, S, S, dtype=torch.float16, device="cuda")
query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

compiled = torch.compile(templated_attention)
out = compiled(query, key, value, score_mod=alibi)
out2 = templated_attention(query, key, value, score_mod=alibi)
print((out - out2).abs().mean())
assert (out - out2).abs().mean() < 1e-3

print("Flash (no mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value)))
print("Flash (mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=bias)))
print("flexattention: ", do_bench(lambda: compiled(query, key, value, score_mod=alibi)))
```
<img width="324" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/18c175d0-2720-4dfd-8747-85b8a8f609f5">
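Beyond the per-head ALiBi slopes above, any tensor closed over by `score_mod` can now be read inside the generated kernel; the tests added in this PR exercise this with bias lookups via `torch.ops.aten.index`. Below is a minimal sketch of that pattern, reusing `B`, `H`, `S`, `query`, `key`, `value`, and `templated_attention` from the script above; the `bias_mod` helper and bias shape here are illustrative, not part of the change itself.

```
# Illustrative sketch: a score_mod that reads from a captured [B, H, S, S]
# bias tensor, mirroring test_load_from_bias_head_seq_batch in this PR.
bias = torch.randn(B, H, S, S, device="cuda", dtype=torch.float16)


def bias_mod(score, batch, head, token_q, token_kv):
    # Loads from captured tensors are written with aten.index so they can be
    # traced and lowered into the generated attention kernel.
    return score + torch.ops.aten.index(bias, [batch, head, token_q, token_kv])


out_biased = torch.compile(templated_attention)(query, key, value, score_mod=bias_mod)
```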
Differential Revision: [D56583900](https://our.internmc.facebook.com/intern/diff/D56583900)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124444
Approved by: https://github.com/jansel, https://github.com/drisspg
This commit is contained in:
parent
3a810bcf91
commit
7321005dd8
```
@@ -4,7 +4,7 @@ import functools
 from collections import namedtuple
 from typing import Callable
 
-from unittest import expectedFailure, skipUnless
+from unittest import skip, skipUnless
 from unittest.mock import patch
 
 import torch

@@ -36,6 +36,8 @@ supported_platform = skipUnless(
 Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
 torch.set_float32_matmul_precision("high")
 
+index = torch.ops.aten.index
+
 
 def create_attention(score_mod):
     return functools.partial(_templated_attention, score_mod=score_mod)

@@ -47,6 +49,8 @@ test_dtypes = (
     else [torch.float16, torch.float32]
 )
 
+test_dtypes_fast = [torch.float16]
+
 # TODO float16 was causing ERRORs for tests on ROCm
 # See https://github.com/pytorch/pytorch/issues/123531
 if common_utils.TEST_WITH_ROCM:

@@ -65,13 +69,19 @@ def _causal_mod(score, b, h, token_q, token_kv):
     return torch.where(token_q >= token_kv, score, float("-inf"))
 
 
+B = 4
+H = 8
+S = 2048
+D = 64
+
+
 class TestTemplatedSDPA(InductorTestCase):
     def run_test(self, score_mod: Callable, dtype: torch.dtype = torch.float16):
         sdpa_partial = create_attention(score_mod)
         compiled_sdpa = torch.compile(sdpa_partial)
-        q = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
-        k = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
-        v = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
+        q = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
         golden_out = sdpa_partial(
             q.to(torch.float64), k.to(torch.float64), v.to(torch.float64)
         )

@@ -109,23 +119,116 @@ class TestTemplatedSDPA(InductorTestCase):
 
         self.run_test(composed_score_mod, dtype)
 
-    # TODO We are currently not capturing free variables in the closure correctly
-    @expectedFailure
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     def test_captured_buffers(self, dtype: torch.dtype):
-        head_offset = torch.rand(8, device="cuda", dtype=dtype)
+        head_offset = torch.rand(H, device="cuda", dtype=dtype)
 
         def score_mod(score, b, h, m, n):
-            return score + head_offset[h]
+            return score + index(head_offset, [h])
 
         self.run_test(score_mod, dtype)
 
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_seq_masking(self, dtype):
+        seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
+        seq_idx[S // 2 :] = 1
+
+        def seq_mask_mod(score, b, h, q, kv):
+            return torch.where(
+                index(seq_idx, [q]) == index(seq_idx, [kv]), score, float("-inf")
+            )
+
+        self.run_test(seq_mask_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_from_bias_seq_only(self, dtype):
+        bias = torch.randn(S, S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(bias, [q, kv])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_from_bias_seq_batch(self, dtype):
+        bias = torch.randn(B, S, S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(bias, [b, q, kv])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_from_bias_head_seq_batch(self, dtype):
+        bias = torch.randn(B, H, S, S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(bias, [b, h, q, kv])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_rel_bias(self, dtype):
+        rel_bias = torch.randn(2 * S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(rel_bias, [(q - kv) + S])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_dependent_causal_bidirectional(self, dtype):
+        num_bidirectional = torch.randint(0, S, (B,), device="cuda", dtype=torch.int32)
+
+        def bias_mod(score, b, h, q, kv):
+            causal_attention = q >= kv
+            cur_num_bidirectional = index(num_bidirectional, (b,))
+            bidirectional_attention_on_video = (q <= cur_num_bidirectional) & (
+                kv <= cur_num_bidirectional
+            )
+            return torch.where(
+                bidirectional_attention_on_video | causal_attention,
+                score,
+                -float("inf"),
+            )
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @skip("Triton bug ")  # https://github.com/pytorch/pytorch/issues/124571
+    @common_utils.parametrize("dtype", test_dtypes)
+    def test_njt_causal(self, dtype):
+        offsets = torch.tensor(
+            [0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32
+        )
+        seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32)
+        for idx in range(len(offsets) - 1):
+            seq_idx[offsets[idx] : offsets[idx + 1]] = idx
+
+        def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
+            def njt_score_mod(qk, b, h, q, kv):
+                q_nested = q - index(offsets, [index(seq_idx, [q])])
+                kv_nested = kv - index(offsets, [index(seq_idx, [kv])])
+                return orig_score_mod(qk, b, h, q_nested, kv_nested)
+
+            return njt_score_mod
+
+        causal_njt = create_njt_wrapper(_causal_mod, offsets, seq_idx)
+
+        self.run_test(causal_njt, dtype)
+
     @supported_platform
     def test_backwards_fails(self):
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=torch.float32,
             device="cuda",
             requires_grad=True,

@@ -139,9 +242,9 @@ class TestTemplatedSDPA(InductorTestCase):
 
     @supported_platform
     def test_mixed_dtypes_fails(self):
-        query = torch.randn((1, 1, 2048, 64), dtype=torch.float32, device="cuda")
-        key = torch.randn((1, 1, 2048, 64), dtype=torch.float16, device="cuda")
-        value = torch.randn((1, 1, 2048, 64), dtype=torch.float16, device="cuda")
+        query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda")
+        key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
+        value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
         with self.assertRaisesRegex(
             ValueError, "Expected query, key, and value to have the same dtype"
         ):

@@ -163,6 +266,21 @@ class TestTemplatedSDPA(InductorTestCase):
 
         self.run_test(score_mod)
 
+    @supported_platform
+    @patch.object(torch._inductor.config, "max_autotune", True)
+    def test_max_autotune_with_captured(self):
+        head_scale = torch.randn(H, device="cuda")
+        batch_scale = torch.randn(B, device="cuda")
+        tok_scale = torch.randn(S, device="cuda")
+
+        def bias_mod(score, batch, head, token_q, token_kv):
+            score = score + index(tok_scale, [token_q])
+            score = score + index(batch_scale, [batch])
+            score = score + index(head_scale, [head])
+            return score
+
+        self.run_test(bias_mod)
+
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("score_mod", [_identity, _causal])

@@ -173,7 +291,7 @@ class TestTemplatedSDPA(InductorTestCase):
 
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=dtype,
             device="cuda",
             requires_grad=True,

@@ -215,7 +333,7 @@ class TestTemplatedSDPA(InductorTestCase):
     def test_logsumexp_only_return(self):
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=torch.float32,
             device="cuda",
             requires_grad=True,

@@ -236,7 +354,7 @@ class TestTemplatedSDPA(InductorTestCase):
     def test_logsumexp_is_not_fused(self):
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=torch.float32,
             device="cuda",
             requires_grad=True,

@@ -1536,12 +1536,10 @@ class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
     ) -> "VariableTracker":
         from .builder import wrap_fx_proxy
 
-        query, key, value, score_mod, *other_buffers = self.normalize_to_args(
-            args, kwargs
-        )
+        query, key, value, score_mod = self.normalize_to_args(args, kwargs)
 
         p_args, p_kwargs = self.create_wrapped_node(tx, query, score_mod)
-        proxied_args = [query, key, value, *other_buffers]
+        proxied_args = [query, key, value]
 
         # Store the invocation as a call
         # Norm_kwargs contains the score_function and we dont want to proxy this because

@@ -60,7 +60,7 @@ def math_attention(
     """
     assert len(other_buffers) == 0, "Other buffers are not yet supported."
 
-    scores = query @ key.transpose(-2, -1)
+    scores = (query @ key.transpose(-2, -1)).to(dtype=torch.float32)
 
     b = torch.arange(0, scores.size(0), device=scores.device)
     h = torch.arange(0, scores.size(1), device=scores.device)

@@ -179,9 +179,11 @@ def templated_attention_functionalize(
     assert isinstance(other_buffers_unwrapped, tuple)
     assert all(isinstance(item, torch.Tensor) for item in other_buffers_unwrapped)
 
-    example_vals = [torch.zeros((), dtype=query.dtype)] + [
-        torch.zeros((), dtype=torch.int) for _ in range(4)
-    ]
+    example_vals = (
+        [torch.zeros((), dtype=query.dtype)]
+        + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+        + list(other_buffers_unwrapped)
+    )
     with ctx.redispatch_to_next() as m:
         functional_score_mod = ctx.functionalize(score_mod)
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch

@@ -3413,22 +3413,14 @@ class TritonScheduling(BaseScheduling):
             buffer_names.update(node.used_buffer_names())
 
         # Get buffers objects
-        def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]:
-            if name in V.graph.name_to_buffer:
-                return V.graph.name_to_buffer[name]
-            elif name in V.graph.graph_inputs:
-                return V.graph.graph_inputs[name]
-            elif name in V.graph.constants:
-                data = V.graph.constants[name]
-                return ir.ConstantBuffer(
-                    name,
-                    ir.FixedLayout(
-                        data.device, data.dtype, *V.graph.static_sizes_strides(data)
-                    ),
-                )
-            raise RuntimeError(f"Failed to find buffer matching name {name}")
 
-        buffers = [_get_buffer(name) for name in buffer_names]
+        def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]:
+            buf = V.graph.get_buffer(name)
+            if buf is None:
+                raise RuntimeError(f"Failed to find buffer matching name {name}")
+            return buf
+
+        buffers = [V.graph.get_buffer(name) for name in buffer_names]
 
         # In theory we can separately check xnumel and rnumel are <= int_max
         # but some indexers do use the full linear index so we need to be

@@ -665,6 +665,14 @@ class GraphLowering(torch.fx.Interpreter):
             return self.name_to_buffer[buffer_name]
         if buffer_name in self.graph_inputs:
             return self.graph_inputs[buffer_name]
+        if buffer_name in self.constants:
+            data = V.graph.constants[buffer_name]
+            return ir.ConstantBuffer(
+                buffer_name,
+                ir.FixedLayout(
+                    data.device, data.dtype, *V.graph.static_sizes_strides(data)
+                ),
+            )
         return None
 
     def get_dtype(self, buffer_name: str):

@@ -3,6 +3,7 @@ import logging
 from typing import Any, List
 
 import torch
+from .. import config
 from ..lowering import empty_strided, lowerings, register_lowering
 from ..select_algorithm import autotune_select_algorithm, TritonTemplate

@@ -114,12 +115,14 @@ sdpa_template = TritonTemplate(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk = tl.dot(q, k.to(MATMUL_PRECISION), acc=qk)
         # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
+        m = offs_m[:, None]
+        n = start_n + offs_n[None, :]
         {{ modification(
             score="qk",
             b="off_hz // H",
             h="off_hz % H",
-            m="offs_m[:, None]",
-            n="start_n + offs_n[None, :]",
+            m="m",
+            n="n",
             out="qk"
         ) | indent_except_first(2) }}
         # TODO: In the case that score_mod is linear, this can be LICMed

@@ -170,7 +173,8 @@ sdpa_template = TritonTemplate(
 )
 
 
-@register_lowering(torch.ops.higher_order.templated_attention)
+# TODO: We probably also need a layout constraint?
+@register_lowering(torch.ops.higher_order.templated_attention, type_promotion_kind=None)
 def templated_attention(*args, **kwargs):
     from torch._prims_common import make_contiguous_strides_for
     from ..ir import (

@@ -182,7 +186,7 @@ def templated_attention(*args, **kwargs):
         TensorBox,
     )
 
-    query, key, value, subgraph = args
+    query, key, value, subgraph, *other_buffers = args
 
     def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer:
         return TensorBox.create(

@@ -272,17 +276,23 @@ def templated_attention(*args, **kwargs):
     configs: List[Any] = []
     if query.get_dtype() == torch.float32:
         configs.append((64, 64, 4, 3))
+    else:
+        configs.append((128, 64, 4, 3))
+    if config.max_autotune:
         configs += [
             (128, 64, 4, 3),
             (128, 128, 4, 3),
             (128, 128, 8, 2),
             (64, 128, 4, 3),
+            (64, 64, 4, 3),
         ]
+    # Note, we don't need to pass in the captured buffers explicitly
+    # because they're implicitly added by the score_mod function
+    # We do need to explicitly pass it in for autotuning though.
     for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
         sdpa_template.maybe_append_choice(
             choices=choices,
-            input_nodes=(query, key, value, logsumexp),
+            input_nodes=[query, key, value, logsumexp],
             layout=layout,
             subgraphs=subgraph_buffer,
             mutated_inputs=[

@@ -298,9 +308,10 @@ def templated_attention(*args, **kwargs):
             ROWS_GUARANTEED_SAFE=False,
             OUTPUT_LOGSUMEXP=True,
         )
+    inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers)
     return (
         autotune_select_algorithm(
-            "sdpa", choices, [query, key, value, logsumexp], layout
+            "sdpa", choices, inputs_for_autotuning, layout
         ),
         logsumexp,
     )

@@ -194,7 +194,6 @@ class CachingAutotuner(KernelInterface):
        compiled_binaries = []
        if not self.configs:
            raise RuntimeError("No triton configs are available")
 
        for c in self.configs:
            try:
                compiled_binary, launcher = self._precompile_config(

@@ -202,11 +201,8 @@ class CachingAutotuner(KernelInterface):
                )
            except OutOfResources as e:
                if len(self.configs) == 1:
-                    raise RuntimeError(
-                        f"Failed to compile triton config: {c}. "
-                        f"Report a fatal compilation error. "
-                        f"{e}"
-                    )
+                    # There are no valid Triton configs
+                    raise e
                # Skip the config if we run out of resource
                continue
            self.launchers.append(launcher)

@@ -36,7 +36,14 @@ from .codegen.triton_utils import config_of, signature_to_meta
 from .exc import CUDACompileError
 from .ir import ChoiceCaller, PrimitiveInfoType
 from .runtime.runtime_utils import do_bench
-from .utils import get_dtype_size, Placeholder, sympy_dot, sympy_product, unique
+from .utils import (
+    get_dtype_size,
+    Placeholder,
+    sympy_dot,
+    sympy_index_symbol,
+    sympy_product,
+    unique,
+)
 from .virtualized import V
 
 log = logging.getLogger(__name__)

@@ -269,20 +276,23 @@ class TritonTemplateKernel(TritonKernel):
         potential multiple modifications
         """
 
+        def add_input(name):
+            return self.args.input(name)
+
         class PlaceholderSubstitution(V.WrapperHandler):  # type: ignore[name-defined]
             self.name = "PlaceholderSubstitution"
 
             def load(self, name: str, index: sympy.Expr):
                 if name not in fixed_inputs:
-                    raise AssertionError(
-                        f"All loads should be coming from fixed inputs - {name}"
-                    )
+                    # If it's not a fixed input, it's a load from a captured
+                    # tensor
+                    var = add_input(name)
+                    return f"tl.load({var} + {index})"
 
                 return f"({fixed_inputs[name]})"
 
-            # TODO Doesn't work yet
             def indirect_indexing(self, index_var, size, check):
-                return self._inner.indirect_indexing(index_var, size, False)
-                # return sympy_symbol(str(index_var))
+                return sympy_index_symbol(str(index_var))
 
         # if self.modification_cache is None:
         with V.set_ops_handler(PlaceholderSubstitution(V.ops)):

@@ -589,16 +599,25 @@ class TritonTemplate(KernelTemplate):
             + "-"
         )
         mod = PyCodeCache.load(code, extra)
-        _, call_args, _ = kernel.args.python_argdefs()
 
-        expected_args = list(unique(x.get_name() for x in input_nodes))
-        expected_args.extend([fake_out.get_name()])
-        assert list(call_args)[: len(expected_args)] == expected_args, (
-            call_args,
-            expected_args,
+        input_call_args = tuple(kernel.args.input_buffers.keys())
+        output_call_args = tuple(kernel.args.output_buffers.keys())
+
+        # We expect the input_buffer order to be [*input_nodes, *captured_buffers]
+        expected_input_args = tuple(unique(x.get_name() for x in input_nodes))
+        expected_output_args = (fake_out.get_name(),)
+        assert input_call_args[: len(expected_input_args)] == expected_input_args, (
+            input_call_args,
+            expected_input_args,
         )
+        assert output_call_args == expected_output_args, (
+            output_call_args,
+            expected_output_args,
+        )
+
+        full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args])
         extra_args = V.graph.sizevars.size_hints(
-            map(sympy.expand, call_args[len(expected_args) :]),
+            map(sympy.expand, tuple(kernel.args.sizevars.keys())),
             fallback=config.unbacked_symint_fallback,
         )

@@ -636,13 +655,13 @@ class TritonTemplate(KernelTemplate):
             num_stages=num_stages,
             num_warps=num_warps,
             matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
-            input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
+            input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),
             output_tensor_meta=TensorMeta.from_irnodes(layout),
         )
 
         return TritonTemplateCaller(
             kernel_hash_name,
-            input_nodes,
+            full_input_nodes,
             layout,
             make_kernel_render,
             extra.strip("-").replace("-", ", "),

@@ -994,6 +1013,7 @@ class AlgorithmSelectorCache(PersistentCache):
             [c for c in choices if hasattr(c, "precompile")],
             timeout=precompilation_timeout_seconds,
         )
+        from triton.runtime.autotuner import OutOfResources
 
         @functools.lru_cache(None)
         def wait_on_futures():

@@ -1013,6 +1033,9 @@ class AlgorithmSelectorCache(PersistentCache):
                     )
                 except StopIteration:
                     pass
+                except OutOfResources:
+                    # This config is invalid due to requiring too many resources
+                    pass
 
             executor.shutdown(wait=True)
```