Add support for capturing tensors with score_mod (#124444)

```
import torch
import torch.nn.functional as F

from torch.nn.attention._templated_attention import _templated_attention as templated_attention
from triton.testing import do_bench

torch.manual_seed(0)

B = 16
H = 16
S = 2048
D = 64

# Captured tensor: per-head ALiBi-style slopes read inside score_mod.
head_scale = torch.randn(H, device="cuda")

def alibi(score, batch, head, token_q, token_kv):
    return score + torch.ops.aten.index(head_scale, [head]) * (token_q - token_kv)

# Dense additive bias for the masked SDPA baseline below.
bias = torch.randn(H, S, S, dtype=torch.float16, device="cuda")

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

compiled = torch.compile(templated_attention)
out = compiled(query, key, value, score_mod=alibi)
out2 = templated_attention(query, key, value, score_mod=alibi)
print((out - out2).abs().mean())
assert (out - out2).abs().mean() < 1e-3

print("Flash (no mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value)))
print("Flash (mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=bias)))
print("flexattention: ", do_bench(lambda: compiled(query, key, value, score_mod=alibi)))
```
<img width="324" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/18c175d0-2720-4dfd-8747-85b8a8f609f5">
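The same capture mechanism also covers tensor-valued biases indexed by the score positions. A minimal sketch of that pattern (mirroring the `test_load_rel_bias` case added in this PR; the shapes here are illustrative, not the benchmark's):

```
import torch
from torch.nn.attention._templated_attention import _templated_attention as templated_attention

B, H, S, D = 4, 8, 1024, 64

# Captured relative-position bias table, indexed by the (q - kv) offset.
rel_bias = torch.randn(2 * S, device="cuda", dtype=torch.float16)

def rel_bias_mod(score, b, h, q, kv):
    return score + torch.ops.aten.index(rel_bias, [(q - kv) + S])

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
out = torch.compile(templated_attention)(query, key, value, score_mod=rel_bias_mod)
```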

Differential Revision: [D56583900](https://our.internmc.facebook.com/intern/diff/D56583900)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124444
Approved by: https://github.com/jansel, https://github.com/drisspg
Author: chilli (2024-04-25 10:35:47 -07:00), committed by PyTorch MergeBot
Commit: 7321005dd8 (parent: 3a810bcf91)
8 changed files with 221 additions and 73 deletions

View File

@@ -4,7 +4,7 @@ import functools
 from collections import namedtuple
 from typing import Callable
-from unittest import expectedFailure, skipUnless
+from unittest import skip, skipUnless
 from unittest.mock import patch

 import torch
@@ -36,6 +36,8 @@ supported_platform = skipUnless(
 Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
 torch.set_float32_matmul_precision("high")

+index = torch.ops.aten.index
+
 def create_attention(score_mod):
     return functools.partial(_templated_attention, score_mod=score_mod)
@@ -47,6 +49,8 @@ test_dtypes = (
     else [torch.float16, torch.float32]
 )

+test_dtypes_fast = [torch.float16]
+
 # TODO float16 was causing ERRORs for tests on ROCm
 # See https://github.com/pytorch/pytorch/issues/123531
 if common_utils.TEST_WITH_ROCM:
@@ -65,13 +69,19 @@ def _causal_mod(score, b, h, token_q, token_kv):
     return torch.where(token_q >= token_kv, score, float("-inf"))

+B = 4
+H = 8
+S = 2048
+D = 64
+
 class TestTemplatedSDPA(InductorTestCase):
     def run_test(self, score_mod: Callable, dtype: torch.dtype = torch.float16):
         sdpa_partial = create_attention(score_mod)
         compiled_sdpa = torch.compile(sdpa_partial)
-        q = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
-        k = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
-        v = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
+        q = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
         golden_out = sdpa_partial(
             q.to(torch.float64), k.to(torch.float64), v.to(torch.float64)
         )
@@ -109,23 +119,116 @@ class TestTemplatedSDPA(InductorTestCase):
         self.run_test(composed_score_mod, dtype)

-    # TODO We are currently not capturing free variables in the closure correctly
-    @expectedFailure
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     def test_captured_buffers(self, dtype: torch.dtype):
-        head_offset = torch.rand(8, device="cuda", dtype=dtype)
+        head_offset = torch.rand(H, device="cuda", dtype=dtype)

         def score_mod(score, b, h, m, n):
-            return score + head_offset[h]
+            return score + index(head_offset, [h])

         self.run_test(score_mod, dtype)

+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_seq_masking(self, dtype):
+        seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
+        seq_idx[S // 2 :] = 1
+
+        def seq_mask_mod(score, b, h, q, kv):
+            return torch.where(
+                index(seq_idx, [q]) == index(seq_idx, [kv]), score, float("-inf")
+            )
+
+        self.run_test(seq_mask_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_from_bias_seq_only(self, dtype):
+        bias = torch.randn(S, S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(bias, [q, kv])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_from_bias_seq_batch(self, dtype):
+        bias = torch.randn(B, S, S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(bias, [b, q, kv])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_from_bias_head_seq_batch(self, dtype):
+        bias = torch.randn(B, H, S, S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(bias, [b, h, q, kv])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_rel_bias(self, dtype):
+        rel_bias = torch.randn(2 * S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(rel_bias, [(q - kv) + S])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_dependent_causal_bidirectional(self, dtype):
+        num_bidirectional = torch.randint(0, S, (B,), device="cuda", dtype=torch.int32)
+
+        def bias_mod(score, b, h, q, kv):
+            causal_attention = q >= kv
+            cur_num_bidirectional = index(num_bidirectional, (b,))
+            bidirectional_attention_on_video = (q <= cur_num_bidirectional) & (
+                kv <= cur_num_bidirectional
+            )
+            return torch.where(
+                bidirectional_attention_on_video | causal_attention,
+                score,
+                -float("inf"),
+            )
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @skip("Triton bug ")  # https://github.com/pytorch/pytorch/issues/124571
+    @common_utils.parametrize("dtype", test_dtypes)
+    def test_njt_causal(self, dtype):
+        offsets = torch.tensor(
+            [0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32
+        )
+        seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32)
+        for idx in range(len(offsets) - 1):
+            seq_idx[offsets[idx] : offsets[idx + 1]] = idx
+
+        def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
+            def njt_score_mod(qk, b, h, q, kv):
+                q_nested = q - index(offsets, [index(seq_idx, [q])])
+                kv_nested = kv - index(offsets, [index(seq_idx, [kv])])
+                return orig_score_mod(qk, b, h, q_nested, kv_nested)
+
+            return njt_score_mod
+
+        causal_njt = create_njt_wrapper(_causal_mod, offsets, seq_idx)
+
+        self.run_test(causal_njt, dtype)
+
     @supported_platform
     def test_backwards_fails(self):
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=torch.float32,
             device="cuda",
             requires_grad=True,
@@ -139,9 +242,9 @@ class TestTemplatedSDPA(InductorTestCase):
     @supported_platform
     def test_mixed_dtypes_fails(self):
-        query = torch.randn((1, 1, 2048, 64), dtype=torch.float32, device="cuda")
-        key = torch.randn((1, 1, 2048, 64), dtype=torch.float16, device="cuda")
-        value = torch.randn((1, 1, 2048, 64), dtype=torch.float16, device="cuda")
+        query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda")
+        key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
+        value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
         with self.assertRaisesRegex(
             ValueError, "Expected query, key, and value to have the same dtype"
         ):
@@ -163,6 +266,21 @@ class TestTemplatedSDPA(InductorTestCase):
         self.run_test(score_mod)

+    @supported_platform
+    @patch.object(torch._inductor.config, "max_autotune", True)
+    def test_max_autotune_with_captured(self):
+        head_scale = torch.randn(H, device="cuda")
+        batch_scale = torch.randn(B, device="cuda")
+        tok_scale = torch.randn(S, device="cuda")
+
+        def bias_mod(score, batch, head, token_q, token_kv):
+            score = score + index(tok_scale, [token_q])
+            score = score + index(batch_scale, [batch])
+            score = score + index(head_scale, [head])
+            return score
+
+        self.run_test(bias_mod)
+
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("score_mod", [_identity, _causal])
@@ -173,7 +291,7 @@ class TestTemplatedSDPA(InductorTestCase):
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=dtype,
             device="cuda",
             requires_grad=True,
@@ -215,7 +333,7 @@ class TestTemplatedSDPA(InductorTestCase):
     def test_logsumexp_only_return(self):
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=torch.float32,
             device="cuda",
             requires_grad=True,
@@ -236,7 +354,7 @@ class TestTemplatedSDPA(InductorTestCase):
     def test_logsumexp_is_not_fused(self):
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=torch.float32,
             device="cuda",
             requires_grad=True,

View File

@@ -1536,12 +1536,10 @@ class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
     ) -> "VariableTracker":
         from .builder import wrap_fx_proxy

-        query, key, value, score_mod, *other_buffers = self.normalize_to_args(
-            args, kwargs
-        )
+        query, key, value, score_mod = self.normalize_to_args(args, kwargs)

         p_args, p_kwargs = self.create_wrapped_node(tx, query, score_mod)
-        proxied_args = [query, key, value, *other_buffers]
+        proxied_args = [query, key, value]

         # Store the invocation as a call
         # Norm_kwargs contains the score_function and we dont want to proxy this because

View File

@@ -60,7 +60,7 @@ def math_attention(
     """
     assert len(other_buffers) == 0, "Other buffers are not yet supported."

-    scores = query @ key.transpose(-2, -1)
+    scores = (query @ key.transpose(-2, -1)).to(dtype=torch.float32)

     b = torch.arange(0, scores.size(0), device=scores.device)
     h = torch.arange(0, scores.size(1), device=scores.device)
@@ -179,9 +179,11 @@ def templated_attention_functionalize(
     assert isinstance(other_buffers_unwrapped, tuple)
     assert all(isinstance(item, torch.Tensor) for item in other_buffers_unwrapped)

-    example_vals = [torch.zeros((), dtype=query.dtype)] + [
-        torch.zeros((), dtype=torch.int) for _ in range(4)
-    ]
+    example_vals = (
+        [torch.zeros((), dtype=query.dtype)]
+        + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+        + list(other_buffers_unwrapped)
+    )
     with ctx.redispatch_to_next() as m:
         functional_score_mod = ctx.functionalize(score_mod)
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
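For intuition, the eager reference that `math_attention` implements corresponds roughly to the following pure-PyTorch sketch (a broadcasting stand-in for the actual implementation; softmax scaling is omitted, and the fp32 upcast mirrors the change above so low-precision captured buffers don't degrade accuracy):

```
import torch

def math_attention_sketch(query, key, value, score_mod):
    # Materialize the full score matrix in fp32, apply score_mod pointwise with
    # broadcasted (b, h, q, kv) index tensors, then softmax and weight the values.
    scores = (query @ key.transpose(-2, -1)).to(dtype=torch.float32)
    B, H, M, N = scores.shape
    b = torch.arange(B, device=scores.device)[:, None, None, None]
    h = torch.arange(H, device=scores.device)[None, :, None, None]
    q = torch.arange(M, device=scores.device)[None, None, :, None]
    kv = torch.arange(N, device=scores.device)[None, None, None, :]
    scores = score_mod(scores, b, h, q, kv)
    return scores.softmax(dim=-1).to(query.dtype) @ value
```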

View File

@@ -3413,22 +3413,14 @@ class TritonScheduling(BaseScheduling):
             buffer_names.update(node.used_buffer_names())

         # Get buffers objects
-        def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]:
-            if name in V.graph.name_to_buffer:
-                return V.graph.name_to_buffer[name]
-            elif name in V.graph.graph_inputs:
-                return V.graph.graph_inputs[name]
-            elif name in V.graph.constants:
-                data = V.graph.constants[name]
-                return ir.ConstantBuffer(
-                    name,
-                    ir.FixedLayout(
-                        data.device, data.dtype, *V.graph.static_sizes_strides(data)
-                    ),
-                )
-            raise RuntimeError(f"Failed to find buffer matching name {name}")
-
-        buffers = [_get_buffer(name) for name in buffer_names]
+        def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]:
+            buf = V.graph.get_buffer(name)
+            if buf is None:
+                raise RuntimeError(f"Failed to find buffer matching name {name}")
+            return buf
+
+        buffers = [V.graph.get_buffer(name) for name in buffer_names]

         # In theory we can separately check xnumel and rnumel are <= int_max
         # but some indexers do use the full linear index so we need to be

View File

@@ -665,6 +665,14 @@ class GraphLowering(torch.fx.Interpreter):
             return self.name_to_buffer[buffer_name]
         if buffer_name in self.graph_inputs:
             return self.graph_inputs[buffer_name]
+        if buffer_name in self.constants:
+            data = V.graph.constants[buffer_name]
+            return ir.ConstantBuffer(
+                buffer_name,
+                ir.FixedLayout(
+                    data.device, data.dtype, *V.graph.static_sizes_strides(data)
+                ),
+            )
         return None

     def get_dtype(self, buffer_name: str):

View File

@@ -3,6 +3,7 @@ import logging
 from typing import Any, List

 import torch
+from .. import config
 from ..lowering import empty_strided, lowerings, register_lowering
 from ..select_algorithm import autotune_select_algorithm, TritonTemplate
@@ -114,12 +115,14 @@ sdpa_template = TritonTemplate(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk = tl.dot(q, k.to(MATMUL_PRECISION), acc=qk)
         # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
+        m = offs_m[:, None]
+        n = start_n + offs_n[None, :]
         {{ modification(
             score="qk",
             b="off_hz // H",
             h="off_hz % H",
-            m="offs_m[:, None]",
-            n="start_n + offs_n[None, :]",
+            m="m",
+            n="n",
             out="qk"
         ) | indent_except_first(2) }}
         # TODO: In the case that score_mod is linear, this can be LICMed
@@ -170,7 +173,8 @@ sdpa_template = TritonTemplate(
 )

-@register_lowering(torch.ops.higher_order.templated_attention)
+# TODO: We probably also need a layout constraint?
+@register_lowering(torch.ops.higher_order.templated_attention, type_promotion_kind=None)
 def templated_attention(*args, **kwargs):
     from torch._prims_common import make_contiguous_strides_for
     from ..ir import (
@@ -182,7 +186,7 @@ def templated_attention(*args, **kwargs):
         TensorBox,
     )

-    query, key, value, subgraph = args
+    query, key, value, subgraph, *other_buffers = args

     def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer:
         return TensorBox.create(
@@ -272,17 +276,23 @@ def templated_attention(*args, **kwargs):
     configs: List[Any] = []
     if query.get_dtype() == torch.float32:
         configs.append((64, 64, 4, 3))
+    else:
+        configs.append((128, 64, 4, 3))
+    if config.max_autotune:
         configs += [
             (128, 64, 4, 3),
             (128, 128, 4, 3),
             (128, 128, 8, 2),
             (64, 128, 4, 3),
+            (64, 64, 4, 3),
         ]
+
+    # Note, we don't need to pass in the captured buffers explicitly
+    # because they're implicitly added by the score_mod function
+    # We do need to explicitly pass it in for autotuning though.
     for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
         sdpa_template.maybe_append_choice(
             choices=choices,
-            input_nodes=(query, key, value, logsumexp),
+            input_nodes=[query, key, value, logsumexp],
             layout=layout,
             subgraphs=subgraph_buffer,
             mutated_inputs=[
@@ -298,9 +308,10 @@ def templated_attention(*args, **kwargs):
             ROWS_GUARANTEED_SAFE=False,
             OUTPUT_LOGSUMEXP=True,
         )
+    inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers)
     return (
         autotune_select_algorithm(
-            "sdpa", choices, [query, key, value, logsumexp], layout
+            "sdpa", choices, inputs_for_autotuning, layout
         ),
         logsumexp,
     )
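Outside the test suite, the `max_autotune` path with captured buffers can be exercised roughly like this (a sketch, not part of the PR; `inductor_config.patch` is the usual way to flip the flag, and the shapes are illustrative):

```
import torch
import torch._inductor.config as inductor_config
from torch.nn.attention._templated_attention import _templated_attention as templated_attention

B, H, S, D = 4, 8, 1024, 64
head_scale = torch.randn(H, device="cuda")

def bias_mod(score, b, h, q, kv):
    # Load from the captured per-head scale; the lowering appends the captured
    # buffer to the kernel inputs and to inputs_for_autotuning.
    return score + torch.ops.aten.index(head_scale, [h])

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

with inductor_config.patch(max_autotune=True):
    out = torch.compile(templated_attention)(query, key, value, score_mod=bias_mod)
```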

View File

@@ -194,7 +194,6 @@ class CachingAutotuner(KernelInterface):
         compiled_binaries = []
         if not self.configs:
             raise RuntimeError("No triton configs are available")
-
         for c in self.configs:
             try:
                 compiled_binary, launcher = self._precompile_config(
@@ -202,11 +201,8 @@ class CachingAutotuner(KernelInterface):
                 )
             except OutOfResources as e:
                 if len(self.configs) == 1:
-                    raise RuntimeError(
-                        f"Failed to compile triton config: {c}. "
-                        f"Report a fatal compilation error. "
-                        f"{e}"
-                    )
+                    # There are no valid Triton configs
+                    raise e
                 # Skip the config if we run out of resource
                 continue
             self.launchers.append(launcher)

View File

@@ -36,7 +36,14 @@ from .codegen.triton_utils import config_of, signature_to_meta
 from .exc import CUDACompileError
 from .ir import ChoiceCaller, PrimitiveInfoType
 from .runtime.runtime_utils import do_bench
-from .utils import get_dtype_size, Placeholder, sympy_dot, sympy_product, unique
+from .utils import (
+    get_dtype_size,
+    Placeholder,
+    sympy_dot,
+    sympy_index_symbol,
+    sympy_product,
+    unique,
+)
 from .virtualized import V

 log = logging.getLogger(__name__)
@@ -269,20 +276,23 @@ class TritonTemplateKernel(TritonKernel):
         potential multiple modifications
         """

+        def add_input(name):
+            return self.args.input(name)
+
         class PlaceholderSubstitution(V.WrapperHandler):  # type: ignore[name-defined]
             self.name = "PlaceholderSubstitution"

             def load(self, name: str, index: sympy.Expr):
                 if name not in fixed_inputs:
-                    raise AssertionError(
-                        f"All loads should be coming from fixed inputs - {name}"
-                    )
+                    # If it's not a fixed input, it's a load from a captured
+                    # tensor
+                    var = add_input(name)
+                    return f"tl.load({var} + {index})"

                 return f"({fixed_inputs[name]})"

+            # TODO Doesn't work yet
             def indirect_indexing(self, index_var, size, check):
-                return self._inner.indirect_indexing(index_var, size, False)
+                return sympy_index_symbol(str(index_var))
+                # return sympy_symbol(str(index_var))

         # if self.modification_cache is None:
         with V.set_ops_handler(PlaceholderSubstitution(V.ops)):
@@ -589,16 +599,25 @@ class TritonTemplate(KernelTemplate):
             + "-"
         )
         mod = PyCodeCache.load(code, extra)
-        _, call_args, _ = kernel.args.python_argdefs()
-        expected_args = list(unique(x.get_name() for x in input_nodes))
-        expected_args.extend([fake_out.get_name()])
-        assert list(call_args)[: len(expected_args)] == expected_args, (
-            call_args,
-            expected_args,
+
+        input_call_args = tuple(kernel.args.input_buffers.keys())
+        output_call_args = tuple(kernel.args.output_buffers.keys())
+
+        # We expect the input_buffer order to be [*input_nodes, *captured_buffers]
+        expected_input_args = tuple(unique(x.get_name() for x in input_nodes))
+        expected_output_args = (fake_out.get_name(),)
+
+        assert input_call_args[: len(expected_input_args)] == expected_input_args, (
+            input_call_args,
+            expected_input_args,
         )
+        assert output_call_args == expected_output_args, (
+            output_call_args,
+            expected_output_args,
+        )
+
+        full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args])
         extra_args = V.graph.sizevars.size_hints(
-            map(sympy.expand, call_args[len(expected_args) :]),
+            map(sympy.expand, tuple(kernel.args.sizevars.keys())),
             fallback=config.unbacked_symint_fallback,
         )
@@ -636,13 +655,13 @@ class TritonTemplate(KernelTemplate):
             num_stages=num_stages,
             num_warps=num_warps,
             matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
-            input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
+            input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),
             output_tensor_meta=TensorMeta.from_irnodes(layout),
         )

         return TritonTemplateCaller(
             kernel_hash_name,
-            input_nodes,
+            full_input_nodes,
             layout,
             make_kernel_render,
             extra.strip("-").replace("-", ", "),
@@ -994,6 +1013,7 @@ class AlgorithmSelectorCache(PersistentCache):
             [c for c in choices if hasattr(c, "precompile")],
             timeout=precompilation_timeout_seconds,
         )
+        from triton.runtime.autotuner import OutOfResources

         @functools.lru_cache(None)
         def wait_on_futures():
@@ -1013,6 +1033,9 @@ class AlgorithmSelectorCache(PersistentCache):
                     )
                 except StopIteration:
                     pass
+                except OutOfResources:
+                    # This config is invalid due to requiring too many resources
+                    pass

         executor.shutdown(wait=True)
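To make the `PlaceholderSubstitution` change above concrete, here is a small, self-contained sketch of the idea at the string level (the names and the simplification are illustrative; the real handler plugs into Inductor's `V.ops` wrapper machinery and registers arguments via `self.args.input`):

```
# Fixed placeholders (the score and the b/h/m/n indices) are substituted
# textually; any other load is treated as a captured tensor and becomes a
# tl.load from a new kernel argument.
fixed_inputs = {"score": "qk", "b": "off_hz // H", "h": "off_hz % H", "m": "m", "n": "n"}
captured_inputs = []

def load(name, index):
    if name in fixed_inputs:
        return f"({fixed_inputs[name]})"
    captured_inputs.append(name)  # analogous to add_input(name) above
    return f"tl.load({name} + {index})"

print(load("m", "0"))           # -> (m)
print(load("head_scale", "h"))  # -> tl.load(head_scale + h)
print(captured_inputs)          # -> ['head_scale']
```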