# pytorch/test/inductor/test_max_autotune.py
# Owner(s): ["module: inductor"]
import contextlib
import json
import math
import os
import random
import tempfile
import unittest
from typing import Callable, Optional
import torch
from torch import multiprocessing as mp, nn
from torch._dynamo import reset
from torch._dynamo.exc import BackendCompilerFailed
from torch._dynamo.testing import rand_strided, reset_rng_state
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.autotune_process import (
_TestBenchmarkRequest,
CUDA_VISIBLE_DEVICES,
TuningProcess,
TuningProcessPool,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import (
AlgorithmSelectorCache,
TritonTemplateCaller,
)
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_WINDOWS,
parametrize,
TEST_WITH_ROCM,
)
from torch.utils._triton import has_triton_tma_device
aten = torch.ops.aten
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import MI300_ARCH, runOnRocmArch, skipIfXpu
from torch.testing._internal.inductor_utils import (
get_func_call,
get_kernel_launch,
GPU_TYPE,
HAS_CPU,
HAS_CUDA,
HAS_GPU,
)
torch.set_float32_matmul_precision("high")
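# Note: expandable segments are disabled below, presumably to keep the CUDA caching
# allocator's behavior predictable while candidate kernels are benchmarked.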
if HAS_CUDA:
torch.cuda.memory._set_allocator_settings("expandable_segments:False")
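# Helper run in a spawned child process: benchmarks a ChoiceCaller and copies the
# resulting timings into a shared tensor (mutations to a plain Python list in a
# subprocess are not synced back to the parent).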
def benchmark_choice(choice, args, out, expected_out, timings):
result = choice.benchmark(*args, out=out)
if expected_out is not None:
torch.testing.assert_close(out, expected_out)
timings.copy_(torch.tensor(result))
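# A ChoiceCaller whose benchmark always raises, used to exercise the failure path
# when benchmarking in a subprocess.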
class FailChoiceCaller(ChoiceCaller):
def benchmark(self, *args, out):
raise RuntimeError("This choice caller will always throw")
@instantiate_parametrized_tests
class TestMaxAutotune(TestCase):
def _create_buffer(self, name, shape):
return Buffer(
name=name,
layout=FixedLayout(
torch.device(f"{GPU_TYPE}:0"), dtype=torch.float32, size=shape
),
)
    # XPU does not support multiprocessing reduction in torch/multiprocessing/reductions.py
@skipIfXpu
def test_benchmark_choice_in_subproc(self):
gm = make_fx(
lambda: torch.zeros(2, 3)
)() # a dummy graph to construct the GraphLowering
graph = GraphLowering(gm)
        # the graph handler is needed to create the benchmark example values below
with V.set_graph_handler(graph):
buf1 = self._create_buffer("mat1", (2, 3))
buf2 = self._create_buffer("mat2", (3, 2))
buf3 = self._create_buffer("mat3", (2, 3))
buf4 = self._create_buffer("mat4", (3, 2))
layout = FixedLayout(torch.device(f"{GPU_TYPE}:0"), torch.float32, (2, 2))
mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)
out = AlgorithmSelectorCache.benchmark_example_value(layout)
# expected_out = (mat1 @ mat2) + (mat3 @ mat4)
expected_out = None
choice = aten_mm_plus_mm.bind((buf1, buf2, buf3, buf4), layout)
# use a tensor since the mutation to a python list in a sub process
# is not synced back to the parent process
timings = torch.zeros(3, dtype=torch.float32)
ctx = mp.get_context("spawn")
child = ctx.Process(
target=benchmark_choice,
args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
)
child.start()
child.join()
self.assertEqual(0, child.exitcode)
print(f"timings is {timings}, out {out}, expected_out {expected_out}")
    # XPU does not support multiprocessing reduction in torch/multiprocessing/reductions.py
@skipIfXpu
def test_benchmark_choice_fail_in_subproc(self):
gm = make_fx(
lambda: torch.zeros(2, 3)
)() # a dummy graph to construct the GraphLowering
graph = GraphLowering(gm)
        # the graph handler is needed to create the benchmark example values below
with V.set_graph_handler(graph):
buf1 = self._create_buffer("mat1", (2, 3))
buf2 = self._create_buffer("mat2", (3, 2))
buf3 = self._create_buffer("mat3", (2, 3))
buf4 = self._create_buffer("mat4", (3, 2))
layout = FixedLayout(torch.device(f"{GPU_TYPE}:0"), torch.float32, (2, 2))
mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)
out = AlgorithmSelectorCache.benchmark_example_value(layout)
expected_out = (mat1 @ mat2) + (mat3 @ mat4)
choice = FailChoiceCaller("fail_choice_caller", [], None, description="")
# use a tensor since python list is not synced back
timings = torch.zeros(3, dtype=torch.float32)
ctx = mp.get_context("spawn")
child = ctx.Process(
target=benchmark_choice,
args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
)
child.start()
child.join()
self.assertNotEqual(0, child.exitcode)
@parametrize("autotune_in_subproc", (True, False))
@parametrize("autotune_multi_device", (True, False))
def test_max_autotune_mm_plus_mm(self, autotune_in_subproc, autotune_multi_device):
"""
        This previously crashed due to a Triton issue: https://github.com/triton-lang/triton/issues/1298 .
        With autotuning in a subprocess, we don't crash anymore.
"""
m, n, k = 2048, 1536, 64
def mm_plus_mm(a, b, c, d):
return a @ b + c @ d
a = torch.randn(m, k).to(GPU_TYPE)
b = torch.randn(k, n).to(GPU_TYPE)
c = torch.randn(m, k).to(GPU_TYPE)
d = torch.randn(k, n).to(GPU_TYPE)
with config.patch(
{
"max_autotune": True,
"autotune_in_subproc": autotune_in_subproc,
"autotune_multi_device": autotune_multi_device,
}
):
torch.compile(mm_plus_mm)(a, b, c, d)
@parametrize("dynamic", (False, True))
def test_max_autotune_mm_plus_mm_zero_size_input(self, dynamic):
"""
Make sure autotuning mm_plus_mm with zero-size input works without crashes.
"""
m, n, k = 0, 1536, 64
def mm_plus_mm(a, b, c, d):
return a @ b + c @ d
a = torch.randn(m, k).to(GPU_TYPE)
b = torch.randn(k, n).to(GPU_TYPE)
c = torch.randn(m, k).to(GPU_TYPE)
d = torch.randn(k, n).to(GPU_TYPE)
with config.patch({"max_autotune": True}):
torch.compile(mm_plus_mm, dynamic=dynamic)(a, b, c, d)
@parametrize("dynamic", (False, True))
def test_max_autotune_regular_mm(self, dynamic: bool):
"""
        Make sure autotuning mm in sub-processes works without crashes.
"""
def mm(a, b):
a = torch.sin(a)
return a @ b
a = torch.randn(100, 10).to(GPU_TYPE)
b = torch.randn(10, 100).to(GPU_TYPE)
with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
torch.compile(mm, dynamic=dynamic)(a, b)
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@parametrize("a_transposed", (False, True))
@parametrize("b_transposed", (False, True))
@parametrize("dynamic", (False, True))
def test_max_autotune_regular_mm_persistent_tma(
self,
a_transposed: bool,
b_transposed: bool,
dynamic: bool,
):
def mm(a, b):
# TMA requires 16-byte alignment: here we repeat the dims
            # by a factor of 8, as float16 is 2-byte. All dims are
# repeated due to the possible transpositions below.
a = a.repeat(8, 8)
b = b.repeat(8, 8)
if a_transposed:
a = a.T
if b_transposed:
b = b.T
return torch.mm(a, b)
M, N, K = 21, 31, 11
a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda()
b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda()
with config.patch(
{
"max_autotune": True,
"autotune_fallback_to_aten": False,
"triton.enable_persistent_tma_matmul": "1",
"test_configs.autotune_choice_name_regex": "mm_persistent_tma",
}
):
c_actual = torch.compile(mm, dynamic=dynamic)(a, b)
c_expected = mm(a, b)
torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@parametrize("dynamic", (False, True))
def test_max_autotune_regular_mm_persistent_tma_illegal_alignment(self, dynamic):
def mm(a, b):
return torch.mm(a, b)
M, N, K = 21, 31, 11
a = torch.randn(M, K).to(torch.float16).cuda()
b = torch.randn(K, N).to(torch.float16).cuda()
with self.assertRaises(BackendCompilerFailed) as context, config.patch(
{
"max_autotune": True,
"autotune_fallback_to_aten": False,
"triton.enable_persistent_tma_matmul": "1",
"test_configs.autotune_choice_name_regex": "mm_persistent_tma",
}
):
torch.compile(mm, dynamic=dynamic)(a, b)
# Lowering to the persistent+TMA Triton template should be skipped
# if any of the input inner dims are not 16-byte aligned. As a result,
# given the config flags above, we should have no choices left.
self.assertIn("NoValidChoicesError", str(context.exception))
@parametrize("dynamic", (False, True))
def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool):
"""
Make sure autotuning mm with zero-size input works without crashes.
"""
def mm(a, b):
a = torch.sin(a)
return a @ b
a = torch.randn(0, 10).to(GPU_TYPE)
b = torch.randn(10, 100).to(GPU_TYPE)
with config.patch({"max_autotune": True}):
torch.compile(mm, dynamic=dynamic)(a, b)
def test_precompilation_threads(self):
import threading
from typing import Any
from unittest.mock import Mock, patch
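        # Stub ChoiceCaller that records which thread ran precompile(), so the test
        # can assert that precompilation happened off the main thread.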
class FakeChoiceCaller(ChoiceCaller):
def __init__(self) -> None:
super().__init__("none", [], Mock(), description="")
self.thread_id = None
def precompile(self):
self.thread_id = threading.get_ident()
def call_name(self) -> str:
return None
def to_callable(self):
return None
def hash_key(self) -> str:
return str(hash(self))
def output_node(self) -> "TensorBox": # noqa: F821
return None
fake_choices = [FakeChoiceCaller() for i in range(10)]
fake_lookup_result = dict.fromkeys(fake_choices, 0.123)
def no_lookup(
choices: list[ChoiceCaller],
op: str,
inputs: str,
benchmark: Callable[[Any], dict[ChoiceCaller, float]],
) -> Optional[dict[ChoiceCaller, float]]:
if benchmark is not None:
return benchmark(choices)
asc = AlgorithmSelectorCache()
def fake_benchmark_fn(*args, **kwargs):
return fake_lookup_result
main_thread_id = threading.get_ident()
mock_debug_handler = Mock()
old_debug_handler = V.debug
try:
V.set_debug_handler(mock_debug_handler)
with patch.object(asc, "lookup", new=no_lookup):
with patch.object(
asc, "make_benchmark_fn", return_value=fake_benchmark_fn
):
with config.patch(
{
"autotune_in_subproc": False,
"compile_threads": len(fake_choices),
}
):
asc("test_call", fake_choices, [], Mock())
for fake_choice in fake_choices:
assert (
fake_choice.thread_id is not None
), "Expected all ChoiceCaller's precompile method to have been called"
assert (
fake_choice.thread_id != main_thread_id
), "Expected all ChoiceCaller's precompile method to have been called on separate thread"
finally:
V.set_debug_handler(old_debug_handler)
@parametrize("dynamic", (False, True))
def test_max_autotune_addmm(self, dynamic=False):
"""
        Make sure autotuning addmm in sub-processes works without crashes.
"""
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
def addmm(x, a, b):
return torch.addmm(x, a, b)
x = torch.randn(100).to(GPU_TYPE)
a = torch.randn(100, 10).to(GPU_TYPE)
b = torch.randn(10, 100).to(GPU_TYPE)
with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
Y_compiled = torch.compile(addmm, dynamic=dynamic)(x, a, b)
Y = addmm(x, a, b)
torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@parametrize("a_transposed", (False, True))
@parametrize("b_transposed", (False, True))
@parametrize("dynamic", (False, True))
def test_max_autotune_addmm_persistent_tma(
self,
a_transposed: bool,
b_transposed: bool,
dynamic: bool,
):
def addmm(x, a, b):
# TMA requires 16-byte alignment: here we repeat the dims
            # by a factor of 8, as float16 is 2-byte. All dims are
# repeated due to the possible transpositions below.
x = x.repeat(8)
a = a.repeat(8, 8)
b = b.repeat(8, 8)
if a_transposed:
a = a.T
if b_transposed:
b = b.T
return torch.addmm(x, a, b)
M, N, K = 21, 31, 11
a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda()
b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda()
x = torch.randn(N).to(torch.float16).cuda()
with config.patch(
{
"max_autotune": True,
"autotune_fallback_to_aten": False,
"triton.enable_persistent_tma_matmul": "1",
"test_configs.autotune_choice_name_regex": "mm_persistent_tma",
}
):
c_actual = torch.compile(addmm, dynamic=dynamic)(x, a, b)
c_expected = addmm(x, a, b)
torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@parametrize("dynamic", (False, True))
def test_max_autotune_addmm_persistent_tma_illegal_alignment(self, dynamic):
def addmm(x, a, b):
return torch.addmm(x, a, b)
M, N, K = 21, 31, 11
a = torch.randn(M, K).to(torch.float16).cuda()
b = torch.randn(K, N).to(torch.float16).cuda()
x = torch.randn(N).to(torch.float16).cuda()
with self.assertRaises(BackendCompilerFailed) as context, config.patch(
{
"max_autotune": True,
"autotune_fallback_to_aten": False,
"triton.enable_persistent_tma_matmul": "1",
"test_configs.autotune_choice_name_regex": "mm_persistent_tma",
}
):
torch.compile(addmm, dynamic=dynamic)(x, a, b)
# Lowering to the persistent+TMA Triton template should be skipped
# if any of the input inner dims are not 16-byte aligned. As a result,
# given the config flags above, we should have no choices left.
self.assertIn("NoValidChoicesError", str(context.exception))
@fresh_inductor_cache()
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support persistent TMA")
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@parametrize("carveout", (None, 0, 27))
@parametrize("op", ("mm", "scaled_mm"))
def test_honor_sm_carveout_with_triton_tma(self, carveout, op: str):
def mm_func(a, b):
return torch.mm(a, b)
def scaled_mm(
a,
b,
scale_a,
scale_b,
):
return torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
        # Create large matrices to ensure we use all possible SMs
size = 2560
a = torch.randn(size, size, device="cuda", dtype=torch.bfloat16)
b = (
torch.randn(size, size, device="cuda", dtype=torch.bfloat16)
.transpose(0, 1)
.contiguous()
.transpose(0, 1)
)
scale_a = torch.tensor(1, dtype=torch.float32, device="cuda")
scale_b = torch.tensor(1, dtype=torch.float32, device="cuda")
args = (
(a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn), scale_a, scale_b)
if op == "scaled_mm"
else (a, b)
)
func = scaled_mm if op == "scaled_mm" else mm_func
# Set the specified carveout value
torch._C._set_sm_carveout_experimental(carveout)
if carveout is None:
self.assertIsNone(torch._C._get_sm_carveout_experimental())
else:
self.assertEqual(torch._C._get_sm_carveout_experimental(), carveout)
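        # With a carveout in effect, the persistent TMA kernel is expected to launch
        # on (multiprocessor count - carveout) blocks; this is checked against the
        # profiler trace below.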
with config.patch(
{
"max_autotune": True,
"autotune_fallback_to_aten": False,
"triton.enable_persistent_tma_matmul": True,
"max_autotune_gemm_backends": "TRITON",
"test_configs.autotune_choice_name_regex": "tma",
}
):
compiled_mm = torch.compile(func, mode="max-autotune-no-cudagraphs")
compiled_mm(*args) # Warm-up compilation
with tempfile.NamedTemporaryFile() as f:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA]
) as prof:
# Run with the specified carveout
compiled_mm(*args)
# Export trace and analyze results
prof.export_chrome_trace(f.name)
# Extract grid sizes from the trace events for TMA kernels
kernel_name = "triton_tem_fused"
kernel_events = [
{
"grid": evt.get("args", {}).get("grid", []),
"grid_size": math.prod(evt.get("args", {}).get("grid", [])),
}
for evt in json.load(open(f.name))["traceEvents"]
if evt.get("cat", "") == "kernel"
and kernel_name in evt.get("name", "").lower()
]
# We should have exactly 1 kernel event for this run
self.assertEqual(
len(kernel_events),
1,
f"Expected exactly 1 kernel event, but got {len(kernel_events)}",
)
# Check that grid size matches expected values based on carveout
                max_grid_size = torch.cuda.get_device_properties(
                    "cuda"
                ).multi_processor_count
                effective_carveout = 0 if carveout is None else carveout
                expected_grid_size = max_grid_size - effective_carveout
self.assertEqual(
kernel_events[0]["grid_size"],
expected_grid_size,
f"Grid size {kernel_events[0]['grid_size']} doesn't match {expected_grid_size} for carveout={carveout}",
)
@parametrize("dynamic", (False, True))
def test_max_autotune_addmm_zero_size_input(self, dynamic):
"""
Make sure autotuning addmm with zero-size input works without crashes.
"""
def addmm(x, a, b):
return torch.addmm(x, a, b)
x = torch.randn(100).to(GPU_TYPE)
a = torch.randn(0, 10).to(GPU_TYPE)
b = torch.randn(10, 100).to(GPU_TYPE)
with config.patch({"max_autotune": True}):
torch.compile(addmm, dynamic=dynamic)(x, a, b)
def test_autotune_conv1x1(self):
# Assuming input has 3 channels and we want to produce 16 channels as output
conv1x1 = (
torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1)
.to(memory_format=torch.channels_last)
.to(GPU_TYPE)
)
# Example input tensor: batch size = 4, channels = 3, height = 32, width = 32
# The memory format is set to `channels_last`
input_tensor = (
torch.randn(4, 3, 32, 32)
.contiguous(memory_format=torch.channels_last)
.to(GPU_TYPE)
)
with config.patch(
{"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
):
@torch.compile()
def foo(mod, x):
return mod(x)
with torch.no_grad():
out, code = run_and_get_code(foo, conv1x1, input_tensor)
FileCheck().check_not("extern_kernels.convolution").run(code[0])
self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)
def test_filled_cache_precompile(self):
def fn(a, b, c):
a = (a @ b) @ c
a, b, c = (t.to(torch.float16) for t in [a, b, c])
return (a @ b) @ c
fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
inputs = [torch.rand([256, 256], device=GPU_TYPE) for _ in range(3)]
from torch._dynamo.utils import counters
self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)
torch._dynamo.reset()
counters.clear()
fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
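        # With the autotune cache already populated by the first compile, recompiling
        # should not trigger any precompilation.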
self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
@fresh_inductor_cache()
@config.patch(search_autotune_cache=True)
def test_search_autotune_cache(self):
def fn(a, b, c):
a = (a @ b) @ c
a, b, c = (t.to(torch.float16) for t in [a, b, c])
return (a @ b) @ c
fn_c = torch.compile()(fn)
inputs = [torch.rand([256, 256], device=GPU_TYPE) for _ in range(3)]
from torch._dynamo.utils import counters
self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)
self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
@fresh_inductor_cache()
@config.patch(max_autotune=True, max_fusion_size=2)
def test_jit_fusion_matches_aot_fusion(self):
# In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due
        # to proximity; we want to make sure the AOT-compile pass does the same.
# AOT could do fuse(buf2, buf4) instead if buf3 was pushed to the end
# of the V.graph.buffers list because fuse(buf2, buf4) would have a
# better proximity score than fuse(buf1, buf2). This scenario is possible
# since finalizing MultiTemplateBuffers needs to replace buffers.
def fn(x, number):
buf0 = x + x
buf1 = number.item()
buf2 = x * x
buf3 = x @ x # MultiTemplateBuffer
buf4 = x**2
return buf0, buf1, buf2, buf3, buf4
inputs = (
torch.rand([256, 256], device=GPU_TYPE),
torch.tensor(3, device=GPU_TYPE),
)
torch._export.aot_compile(fn, args=inputs)
@config.patch(autotune_local_cache=False, autotune_remote_cache=False)
@runOnRocmArch(MI300_ARCH)
def test_precompilations(self):
def fn(a, b, c):
a = (a @ b) @ c
a, b, c = (t.to(torch.float16) for t in [a, b, c])
return (a @ b) @ c
fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
inputs = [torch.rand([256, 256], device=GPU_TYPE) for _ in range(3)]
torch.testing.assert_close(fn_c(*inputs), fn(*inputs), atol=1e-2, rtol=1e-2)
from torch._dynamo.utils import counters
self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 2)
def test_cat_addmm(self):
def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
return torch.cat(
[
torch.addmm(a, b, c),
torch.addmm(b, c, a),
],
1,
)
args = [
torch.randn(4, 4, device=GPU_TYPE),
torch.randn(4, 4, device=GPU_TYPE),
torch.randn(4, 4, device=GPU_TYPE),
]
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "Triton",
}
):
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
def test_triton_template_with_epilogues_and_dynamic_shape(self):
def fn(
x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor, mul: torch.Tensor
) -> torch.Tensor:
return (
torch.nn.functional.relu(
torch.matmul(torch.transpose(x, 0, 1), torch.transpose(w, 0, 1))
+ bias
)
* mul
)
M0 = 5
M1 = 8
K = 4
N = 3
w = torch.rand(N, K).to(GPU_TYPE).half()
b = torch.rand(N).to(GPU_TYPE).half()
with config.patch(
{
"max_autotune": True,
"autotune_in_subproc": True,
"max_autotune_gemm_backends": "Triton",
}
):
compiled_fn = torch.compile(
fn, fullgraph=True, dynamic=True, mode="max-autotune-no-cudagraphs"
)
x0 = torch.rand(K, M0).to(GPU_TYPE).half()
mul0 = torch.rand(M0, N).to(GPU_TYPE).half()
y0 = compiled_fn(x0, w, b, mul0)
y0_expected = fn(x0, w, b, mul0)
torch.testing.assert_close(y0, y0_expected)
x1 = torch.rand(K, M1).to(GPU_TYPE).half()
mul1 = torch.rand(M1, N).to(GPU_TYPE).half()
y1 = compiled_fn(x1, w, b, mul1)
y1_expected = fn(x1, w, b, mul1)
torch.testing.assert_close(y1, y1_expected)
@config.patch(
benchmark_kernel=True,
fallback_random=True,
max_autotune_gemm=True,
)
@parametrize("device", ("cpu", GPU_TYPE))
def test_matmul_dropout(self, device):
def fwd(a, b):
x = a @ b
x = torch.nn.functional.dropout(x, 0.1)
return x
def fn(a, b):
x = fwd(a, b).sum()
x.backward()
return a.grad
N = 128
a = torch.randn(N, N, device=device, requires_grad=True)
b = torch.randn(N, N, device=device)
opt_fn = torch.compile(fn)
reset_rng_state()
ref = fn(a, b)
reset_rng_state()
act = opt_fn(a, b)
if N <= 8:
print(f"ref\n{ref}\nact\n{act}")
torch.testing.assert_close(ref, act, atol=1e-1, rtol=1e-1)
@config.patch(
max_autotune_gemm=True,
)
@unittest.skipIf(
getattr(torch, GPU_TYPE).device_count() < 2,
"Need at least 2 devices for this test",
)
def test_autotune_device_guard(self):
x = torch.randn(1024, 1024, device=f"{GPU_TYPE}:1")
y = torch.randn(1024, 1024, device=f"{GPU_TYPE}:1")
def f(x, y):
return x @ y
with fresh_inductor_cache():
act = torch.compile(f)(x, y)
ref = f(x, y)
self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3))
@config.patch(max_autotune=True)
def test_empty_conv_input(self, kernel_size=3):
x = torch.randn(0, 256, 14, 14, device=GPU_TYPE)
weight = torch.randn(256, 256, kernel_size, kernel_size, device=GPU_TYPE)
def f(x, weight):
return torch.convolution(
x,
weight,
bias=None,
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
)
opt_f = torch.compile(f)
ref = f(x, weight)
act = opt_f(x, weight)
self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3))
@config.patch(max_autotune=True)
def test_empty_conv_input_with_1x1_kernel(self):
self.test_empty_conv_input(kernel_size=1)
@config.patch(max_autotune_gemm_backends="TRITON")
def test_baddmm(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.nn.Parameter(
torch.randn(64, 64, 192, dtype=torch.float16)
)
self.bias = torch.nn.Parameter(
torch.randn(64, 1, 192, dtype=torch.float16)
)
def forward(self, x):
return torch.ops.aten.baddbmm.default(self.bias, x, self.weight)
x = torch.randn(
64, 2048, 64, dtype=torch.float16, requires_grad=False, device=GPU_TYPE
)
mod = M().to(GPU_TYPE)
m_c = torch.compile(mode="max-autotune")(mod)
out, code = run_and_get_code(m_c, x)
self.assertEqual(out, mod(x))
FileCheck().check("triton_tem_fused_baddbmm").run(code[0])
@config.patch(max_autotune=True)
def test_conv1x1_with_free_symbols(self):
"""
Make sure there is no exception due to free symbols.
"""
conv = nn.Conv2d(
3, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False
).to(device=GPU_TYPE)
@torch.compile
def f(x, y, z):
h = y.nonzero().size(0)
w = z.nonzero().size(0)
x = x[:, :, :h, :w]
x = conv(x)
return x
x = torch.randn(4, 3, 224, 224).to(
memory_format=torch.channels_last, device=GPU_TYPE
)
for _ in range(2):
y = torch.randint(0, 10, (224,)).to(device=GPU_TYPE)
z = torch.randint(0, 10, (224,)).to(device=GPU_TYPE)
f(x, y, z)
def _test_cat_max_autotune_impl(self, using_triton_mm):
def f(x, y):
y = torch.cos(y)
x = torch.mm(x, x)
return torch.cat([x, y])
f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f)
inps = [
torch.randn(32, 32, device=GPU_TYPE),
torch.randn(32, 32, device=GPU_TYPE),
]
_, code = run_and_get_code(f_c, inps[0], inps[1])
self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25)
# mm kernel, and cos kernel
count = 2 if using_triton_mm else 1
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), count, exactly=True
).run(code[0])
def f(x, y):
y = torch.cos(y)
x = torch.mm(x, x)
out = torch.cat([x, y])
return out, x + 1
f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f)
_, code = run_and_get_code(f_c, inps[0], inps[1])
self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25)
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), 2, exactly=True
).run(code[0])
def f(x, y):
y = torch.cos(y)
x = torch.mm(x, x)
return torch.cat([x, y]), torch.cat([y, x])
f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f)
self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25)
@config.patch({"test_configs.force_extern_kernel_in_multi_template": True})
def test_cat_max_autotune_extern(self):
self._test_cat_max_autotune_impl(using_triton_mm=False)
@skipIfXpu(
msg="The fusion not happend because it do not speedup on XPU, see issue #146568"
)
@config.patch(max_autotune_gemm_backends="TRITON")
def test_cat_max_autotune_triton(self):
self._test_cat_max_autotune_impl(using_triton_mm=True)
def test_conv_cat(self):
class ToyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=1, bias=False
)
def forward(self, x):
x = self.conv(x)
return torch.cat((x, x + 1))
with torch.no_grad():
m = ToyModel().to(device=GPU_TYPE)
input_tensor = torch.randn(32, 3, 64, 64).to(device=GPU_TYPE)
# convolution is not currently plannable
m = torch.compile(m, mode="max-autotune-no-cudagraphs")
out, code = run_and_get_code(m, input_tensor)
self.assertEqual(out, m(input_tensor))
if not TEST_WITH_ROCM:
FileCheck().check("triton_poi_fused_cat_2.run").run(code[0])
def test_conv3d(self):
fn = torch.nn.functional.conv3d
image = torch.randn([1, 3, 8, 16, 32])
filt = torch.randn([3, 3, 7, 7, 7])
with config.patch({"max_autotune": True}):
expected = fn(image, filt)
actual = torch.compile(fn)(image, filt)
torch.testing.assert_close(actual, expected, atol=6e-5, rtol=0.001)
@config.patch(
max_autotune=True, max_autotune_conv_backends="", layout_optimization=False
)
def test_conv_backend(self):
m = torch.nn.Sequential(
torch.nn.Conv2d(3, 3, 1, 1),
).to(GPU_TYPE)
inp = torch.randn([2, 3, 16, 16]).to(GPU_TYPE)
with self.assertRaises(BackendCompilerFailed) as context:
torch.compile(m)(inp)
self.assertIn("NoValidChoicesError", str(context.exception))
def test_non_contiguous_input_mm(self):
"""
        Make sure the triton template can work with non-contiguous inputs without crashing.
Check https://github.com/pytorch/pytorch/issues/125437 for more details.
"""
x = rand_strided(
(50257, 32768), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE
)
y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE)
@torch.compile(mode="max-autotune")
def f(x, y):
return x @ y
ref = x @ y
act = f(x, y)
torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)
def test_non_contiguous_input_addmm(self):
b = torch.randn((768), dtype=torch.bfloat16, device=GPU_TYPE)
x = rand_strided(
(50257, 32768), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE
)
y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE)
@torch.compile(mode="max-autotune")
def f(x, y):
return torch.addmm(b, x, y)
ref = torch.addmm(b, x, y)
act = f(x, y)
torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)
def test_non_contiguous_input_bmm(self):
x = rand_strided(
(1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device=GPU_TYPE
)
y = rand_strided(
(1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device=GPU_TYPE
)
@torch.compile(mode="max-autotune")
def f(x, y):
return torch.bmm(x, y)
ref = torch.bmm(x, y)
act = f(x, y)
torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)
# TODO: fix accuracy failure of the triton template on XPU.
# and enable this test case.
@skipIfXpu
def test_non_contiguous_input_mm_plus_mm(self):
x1 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE)
y1 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE)
x2 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE)
y2 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE)
@torch.compile(mode="max-autotune")
def f(x1, y1, x2, y2):
return x1 @ y1 + x2 @ y2
ref = x1 @ y1 + x2 @ y2
act = f(x1, y1, x2, y2)
torch.testing.assert_close(act, ref, atol=1e-2, rtol=1e-2)
@config.patch(
max_autotune=True,
max_autotune_gemm_backends="",
autotune_fallback_to_aten=False,
)
def test_no_valid_choices(self):
a = torch.zeros([2, 2], device=GPU_TYPE)
b = torch.zeros([2, 2], device=GPU_TYPE)
with self.assertRaises(BackendCompilerFailed) as context:
torch.compile(lambda a, b: a.matmul(b))(a, b)
self.assertIn("NoValidChoicesError", str(context.exception))
@parametrize("multi_template", (True, False))
@config.patch(
max_autotune=True,
max_autotune_gemm_backends="TRITON",
autotune_fallback_to_aten=False,
)
def test_inf_timing(self, multi_template):
from unittest.mock import patch
lookup = AlgorithmSelectorCache.lookup
def mock_lookup(self, *args, **kwargs):
timings = lookup(self, *args, **kwargs)
return {choice: float("inf") for choice in timings.keys()}
a = torch.zeros([16, 16], device=GPU_TYPE)
b = torch.zeros([16, 16], device=GPU_TYPE)
with patch.object(AlgorithmSelectorCache, "lookup", mock_lookup), config.patch(
benchmark_epilogue_fusion=multi_template
):
with self.assertRaises(BackendCompilerFailed) as context:
torch.compile(lambda a, b: a.matmul(b))(a, b)
self.assertIn("NoValidChoicesError", str(context.exception))
@unittest.skipIf(
not torch.cuda.is_available()
or torch.cuda.get_device_properties().total_memory < 2e10,
"Only if the GPU has at least 20GB memory to be safe",
)
@config.patch(force_shape_pad=True, max_autotune=True)
def test_linear_and_cel(self):
"""
        Simulate a GPU without enough SMs. Make sure max-autotune still
works even when the MultiTritonTemplate encapsulates just extern
kernels.
"""
def mock_is_big_gpu(*args, **kwargs):
return False
B, T, C, V = 32, 1024, 768, 50257
linear = nn.Linear(C, V).bfloat16().to(device=GPU_TYPE)
ce = torch.nn.CrossEntropyLoss()
def f(x, y):
x.grad = None
linear.weight.grad = None
linear.bias.grad = None
loss = ce(linear(x), y)
loss.backward()
return loss
x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16()
x.retain_grad()
y = torch.randint(0, V, (B * T,)).cuda()
import torch._inductor.utils as inductor_utils
with unittest.mock.patch.object(inductor_utils, "is_big_gpu", mock_is_big_gpu):
opt_f = torch.compile(f)
expect = (f(x, y), x.grad, linear.weight.grad, linear.bias.grad)
actual = (opt_f(x, y), x.grad, linear.weight.grad, linear.bias.grad)
assert same(expect, actual, tol=1e-2), f"ref:\n{expect}\nact:\n{actual}"
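# Remote autotune-cache tests: PatchCaches swaps in mock cache backends so that the
# cache traffic can be inspected via global_stats.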
@instantiate_parametrized_tests
class TestMaxAutotuneRemoteCache(TestCase):
def setUp(self):
super().setUp()
PatchCaches.setUp()
def tearDown(self):
super().tearDown()
PatchCaches.tearDown()
@parametrize("dynamic", (False, True))
@config.patch(
{"compile_threads": 1, "prologue_fusion": False}
) # Worker processes do not register PatchCaches() properly
def test_max_autotune_remote_caching(self, dynamic: bool):
from unittest.mock import patch
def mm(a, b):
a = torch.sin(a)
return a @ b
a = torch.randn(100, 10).to(GPU_TYPE)
b = torch.randn(10, 100).to(GPU_TYPE)
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
def f(x, y):
return Model()(x, y)
x = torch.randn(100, 100).to(GPU_TYPE)
y = torch.randn(100, 100).to(GPU_TYPE)
with config.patch(
{
"autotune_local_cache": False,
"autotune_remote_cache": True,
}
), patch.dict(os.environ), PatchCaches():
os.environ.pop("TRITON_CACHE_MANAGER", None)
with config.patch({"max_autotune": True}):
for _ in range(4):
with fresh_inductor_cache():
torch.compile(mm, dynamic=dynamic)(a, b)
reset()
with torch.compiler.config.patch(
{"cache_key_tag": "test"}
), fresh_inductor_cache():
torch.compile(mm, dynamic=dynamic)(a, b)
reset()
global_stats.report()
self.assertEqual(global_stats.autotune_remote, Stats(2, 3, 2))
global_stats.reset()
for _ in range(4):
with fresh_inductor_cache():
torch.compile(f, dynamic=dynamic)(x, y)
reset()
with torch.compiler.config.patch(
{"cache_key_tag": "test"}
), fresh_inductor_cache():
torch.compile(mm, dynamic=dynamic)(a, b)
reset()
global_stats.report()
self.assertEqual(global_stats.autotune_remote, Stats(2, 3, 2))
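# Minimal TritonTemplateCaller wrapper around a _TestBenchmarkRequest, used by the
# TuningProcess and TuningProcessPool tests below.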
class _TestTritonTemplateCaller(TritonTemplateCaller):
def __init__(self, bmreq: _TestBenchmarkRequest):
self.bmreq = bmreq
def __str__(self) -> str:
return "test"
class TestTuningProcess(TestCase):
def check_healthy(self, p: TuningProcess, device: Optional[int] = None):
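        # Round-trip a trivial benchmark request through the tuning process and
        # verify that the expected result comes back.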
result = random.random()
bmreq = _TestBenchmarkRequest(result, device=device)
p.put(bmreq.benchmark)
self.assertEqual(p.get(), result)
def test_tuning_subproc_timeout(self):
p = TuningProcess(None)
bmreq = _TestBenchmarkRequest(0, sleep=120)
p.put(bmreq.benchmark)
with self.assertRaises(TimeoutError):
p.get(timeout=1.0)
# Make sure the TuningProcess is still usable after a timeout.
self.check_healthy(p)
p.shutdown()
def test_tuning_subproc_exception(self):
p = TuningProcess(None)
bmreq = _TestBenchmarkRequest(0, exc=RuntimeError("Fail"))
p.put(bmreq.benchmark)
with self.assertRaises(RuntimeError):
p.get()
# Make sure the TuningProcess is still usable after an exception.
self.check_healthy(p)
p.shutdown()
def test_tuning_subproc_crash(self):
p = TuningProcess(None)
bmreq = _TestBenchmarkRequest(0, crash=True)
p.put(bmreq.benchmark)
with self.assertRaises(EOFError):
p.get()
# Make sure the TuningProcess is still usable after a crash.
self.check_healthy(p)
p.shutdown()
def test_tuning_subproc_killed(self):
p = TuningProcess(None)
p.kill()
self.check_healthy(p)
p.shutdown()
def test_visible_devices(self):
device_list = TuningProcessPool.get_device_list()
for device in device_list:
p = TuningProcess(device)
self.check_healthy(p, device=device)
p.shutdown()
class TestTuningProcessPool(TestCase):
    # Use only one device/subprocess so we test that the process restarts
    # and is usable after a crash.
@config.patch({"autotune_multi_device": False})
def test_tuning_pool_crash(self):
tuning_pool = TuningProcessPool()
# First force the tuning process to crash.
bmreq = _TestBenchmarkRequest(0, crash=True)
choice = _TestTritonTemplateCaller(bmreq)
timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], float("inf"))
# Then send another request and make sure the sub-process
# has restarted and is operational.
bmreq = _TestBenchmarkRequest(3.14)
choice = _TestTritonTemplateCaller(bmreq)
timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], bmreq.result)
tuning_pool.shutdown()
@config.patch({"autotune_multi_device": False})
def test_tuning_pool_timeout(self):
tuning_pool = TuningProcessPool()
# First force the tuning process to timeout.
bmreq = _TestBenchmarkRequest(0, sleep=120)
choice = _TestTritonTemplateCaller(bmreq)
with config.patch({"max_autotune_subproc_result_timeout_seconds": 1.0}):
timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], float("inf"))
# Then send another request and make sure the sub-process
# has restarted and is operational.
bmreq = _TestBenchmarkRequest(3.14)
choice = _TestTritonTemplateCaller(bmreq)
timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], bmreq.result)
tuning_pool.shutdown()
    # XPU has to enable XPU_VISIBLE_DEVICES to control device visibility.
@skipIfXpu
@config.patch({"autotune_multi_device": True})
def test_tuning_pool_multiple_devices(self):
# Adapt the test to the available devices (and whether CUDA_VISIBLE_DEVICES
# is already set in the environment); use a subset of the available devices
# to ensure only the subset are visible to the sub-processes.
if CUDA_VISIBLE_DEVICES in os.environ:
visible_devices = os.environ[CUDA_VISIBLE_DEVICES].split(",")
else:
visible_devices = [str(d) for d in range(torch.cuda.device_count())]
cuda_visible_devices = ",".join(visible_devices[-2:])
with unittest.mock.patch.dict(
os.environ, {CUDA_VISIBLE_DEVICES: cuda_visible_devices}
):
tuning_pool = TuningProcessPool()
choice1 = _TestTritonTemplateCaller(_TestBenchmarkRequest(3.14))
choice2 = _TestTritonTemplateCaller(_TestBenchmarkRequest(2.718))
timings = tuning_pool.benchmark([choice1, choice2])
self.assertEqual(timings[choice1], choice1.bmreq.result)
self.assertEqual(timings[choice2], choice2.bmreq.result)
tuning_pool.shutdown()
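# Tests that pointwise prologues (upcasts, broadcasts, offsets, etc.) are fused into
# the Triton GEMM template when prologue_fusion is enabled.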
@instantiate_parametrized_tests
class TestPrologueFusion(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._stack = contextlib.ExitStack()
cls._stack.enter_context(
config.patch(
{
"max_autotune": True,
"prologue_fusion": True,
"benchmark_epilogue_fusion": False,
"shape_padding": False,
"max_autotune_gemm_backends": "TRITON",
"test_configs.max_mm_configs": 4, # significantly speeds up tests
}
)
)
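    # Helper: assert that the generated wrapper code contains the expected number of
    # kernel launches, tensor allocations ("empty_strided"), and deallocations ("del").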
def check_code(self, code_str, num_kernels, num_allocs, num_deallocs):
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(),
num_kernels,
exactly=True,
).run(code_str)
if num_allocs is not None:
FileCheck().check(get_func_call()).check_count(
"empty_strided", num_allocs, exactly=True
).run(code_str)
# skip the deallocation check when using cpp_wrapper; most deallocations happen
# outside of our control via RAIIAtenTensorHandle
if num_deallocs is not None and not config.cpp_wrapper:
FileCheck().check(get_func_call()).check_count(
"del", num_deallocs, exactly=True
).run(code_str)
@parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250)))
def test_upcast(self, sizes):
M, K, N = sizes
x = torch.rand([M, K], dtype=torch.float16, device=GPU_TYPE)
y = torch.rand([K, N], dtype=torch.float, device=GPU_TYPE)
def foo(x, y):
return x.to(y.dtype) @ y
out, code = run_and_get_code(torch.compile(foo), x, y)
self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
# upcast preserves zero mask
FileCheck().check("a =").check_not("tl.where").check("tl.dot").run(code[0])
@unittest.skip("Triton bug in compilation")
def test_gather_fusion(self):
M, K, N = (64, 128, 256)
x = torch.rand([M, K], dtype=torch.float16, device=GPU_TYPE)
y = torch.rand([K, N], dtype=torch.float16, device=GPU_TYPE)
index = torch.randperm(M, device=GPU_TYPE)
def foo(x, y, index):
return (x[index]) @ y
out, code = run_and_get_code(torch.compile(foo), x, y, index)
self.assertEqual(out, foo(x, y, index), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=3)
# should be done in low precision
(
FileCheck()
.check("for k_idx")
.check_not("to(tl.float32)")
.check("dot")
.run(code[0])
)
@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
def test_low_precision(self):
M = K = N = 128
x = torch.rand([M, K], device=GPU_TYPE).to(torch.float8_e4m3fn)
y = torch.rand([K, N], dtype=torch.bfloat16, device=GPU_TYPE)
def foo(x, y):
return x.to(y.dtype) @ y
out, code = run_and_get_code(torch.compile(foo), x, y)
self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
# should be done in low precision, no arithmetic
(
FileCheck()
.check("for k_idx")
.check_not("to(tl.float32)")
.check("dot")
.run(code[0])
)
def foo(x, y):
return (x.to(y.dtype) + 1) @ y
out, code = run_and_get_code(torch.compile(foo), x, y)
self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
# should not be done in low precision
(
FileCheck()
.check("for k_idx")
.check("to(tl.float32)")
.check("dot")
.run(code[0])
)
def test_downcast(self):
        # per heuristics, don't fuse a downcast into a mm because it would lead to more reads inside the kernel
M, K, N = (64, 128, 256)
x = torch.rand([M, K], dtype=torch.float, device=GPU_TYPE)
y = torch.rand([K, N], dtype=torch.float16, device=GPU_TYPE)
def foo(x, y):
return x.to(y.dtype) @ y
out, code = run_and_get_code(torch.compile(foo), x, y)
self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=3)
@parametrize("sizes", ((64, 128, 256), (64, 64, 64), (64, 120, 64)))
def test_multiple_fusions(self, sizes):
M, K, N = sizes
def foo(x, y):
return ((x - 1.1) @ (y + 1.1)) * 1.1
x = torch.rand([M, K], dtype=torch.float, device=GPU_TYPE)
y = torch.rand([K, N], dtype=torch.float, device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(foo), x, y)
self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
# check that we do not CSE any variables between prologues, epilogues
FileCheck().check("def triton").check_count("= 1.1", 3, exactly=True).check(
"tl.store"
).run(code[0])
@config.patch(
{
"max_autotune_gemm_backends": "Triton",
"benchmark_epilogue_fusion": True,
"max_epilogue_benchmarked_choices": 3,
}
)
@skipIfXpu(
msg="The fusion not happend because it do not speedup on XPU, see issue #146568"
)
def test_pending_fusions_multiple(self):
def multi_use(x, y):
return (x @ x.T) * (y @ y.T)
x = torch.rand([128, 16], device=GPU_TYPE)
y = torch.rand([128, 32], device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(multi_use), x, y)
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), 2, exactly=True
).run(code[0])
self.assertEqual(out, multi_use(x, y), atol=0.05, rtol=0.05)
def resolve_pending(x):
return (x @ x).relu()
x = torch.rand([128, 128], device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(resolve_pending), x)
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), 1, exactly=True
).run(code[0])
self.assertEqual(out, resolve_pending(x), atol=0.05, rtol=0.05)
@config.patch(
{
"max_autotune_gemm_backends": "Triton",
"benchmark_epilogue_fusion": True,
"max_epilogue_benchmarked_choices": 3,
}
)
@skipIfXpu(
msg="The fusion not happend because it do not speedup on XPU, see issue #146568"
)
def test_pending_fusion_pro_and_epi(self):
def test_multiple_fusions(x):
y = x.to(torch.float)
return (y @ y).relu()
x = torch.rand([128, 128], dtype=torch.float16, device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(test_multiple_fusions), x)
FileCheck().check(get_func_call()).check_count(
get_kernel_launch(), 1, exactly=True
).run(code[0])
self.assertEqual(out, test_multiple_fusions(x), atol=0.05, rtol=0.05)
@parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250)))
def test_multiple_inputs(self, sizes):
M, K, N = sizes
def foo(x, y, z):
return (x + y).to(torch.float) @ z
x = torch.rand([M, K], dtype=torch.float16, device=GPU_TYPE)
y = torch.rand([M, K], dtype=torch.float16, device=GPU_TYPE)
z = torch.rand([K, N], dtype=torch.float, device=GPU_TYPE)
out_eager = foo(x, y, z)
out, code = run_and_get_code(torch.compile(foo), x, y, z)
self.assertEqual(out, out_eager, atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=3)
def test_storage_offset_prologue(self):
def foo(a):
q = a[:64, :]
k = a[64:, :]
return torch.mm(q + 2, k - 2)
inp = torch.randn(128, 64, device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(foo), inp)
self.assertEqual(out, foo(inp), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=1)
@config.patch(realize_reads_threshold=1, realize_opcount_threshold=1)
@parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250)))
def test_prologue_multiple_nodes(self, sizes):
M, K, N = sizes
def foo(x, y):
return ((((x * 2) - 1) / 2) @ (y * 4)) * 3.0
x = torch.rand([M, K], dtype=torch.float, device=GPU_TYPE)
y = torch.rand([K, N], dtype=torch.float, device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(foo), x, y)
self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
@parametrize("K", (63, 64))
def test_broadcast_x(self, K):
def foo(x, y):
return (x.expand([1, y.shape[0]]) + 1) @ y
x = torch.rand([1, 1], dtype=torch.float, device=GPU_TYPE)
y = torch.rand([K, 128], dtype=torch.float, device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(foo, dynamic=True), x, y)
self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
def test_broadcast_y(self):
def foo(x, y):
return x @ y
M = 20
N = K = 1
x = torch.rand([M, K], dtype=torch.float, device=GPU_TYPE)
y = torch.rand([K, N], dtype=torch.float, device=GPU_TYPE)
torch._dynamo.mark_dynamic(x, 0)
out, code = run_and_get_code(torch.compile(foo, dynamic=True), x, y)
self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
def test_preserves_zero_analysis(self):
fns = (
(lambda x: x.relu(), False), # preserves zero
(lambda x: x + 1, True), # does not
(
lambda x: torch.hypot(x, x),
True,
), # not handled in analysis, conservatively assume does not preserve
)
def foo(x, y, fn):
return fn(x) @ y
for fn, should_mask in fns:
x = torch.rand([64, 127], dtype=torch.float, device=GPU_TYPE)
y = torch.rand([127, 64], dtype=torch.float, device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(foo), x, y, fn)
self.assertEqual(out, foo(x, y, fn), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
if should_mask:
f = FileCheck().check("k_idx").check("a =").check_same("tl.where")
else:
f = FileCheck().check("k_idx").check("a =").check_not("tl.where")
f.check("tl.dot").run(code[0])
@config.patch(realize_reads_threshold=1, realize_opcount_threshold=1)
@parametrize("benchmark_fusion", (True, False))
def test_prologue_read_into_both_inputs(self, benchmark_fusion):
M = K = 256
# not supported today. it could be, but typically the pointwise nodes would get
# inlined into separate nodes.
def foo(x):
y = (x + 1) * 2
return y @ (y - 2)
with config.patch(benchmark_epilogue_fusion=benchmark_fusion):
x = torch.rand([M, K], dtype=torch.float, device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(foo), x)
self.assertEqual(out, foo(x), atol=0.05, rtol=0.05)
# not guaranteed to fuse, but still checking correctness
if not benchmark_fusion:
self.check_code(
code[0], num_kernels=2, num_allocs=None, num_deallocs=None
)
@config.patch(realize_reads_threshold=1, realize_opcount_threshold=1)
@config.patch(allow_buffer_reuse=False)
def test_mismatched_prologue_group(self):
def foo(x, y, z):
a = (x + 2) * 2
b = a * y
return b @ z
x = torch.rand([1, 256], device=GPU_TYPE)
y = torch.rand([256, 256], device=GPU_TYPE)
z = torch.rand([256, 128], device=GPU_TYPE)
out, code = run_and_get_code(torch.compile(foo), x, y, z)
self.assertEqual(out, foo(x, y, z), atol=0.05, rtol=0.05)
        # there's one more dealloc than there should be because of a buffer reuse. TODO:
        # not sure why disabling buffer reuse doesn't stop it
self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=4)
    # XPU has not enabled pad_mm in fx_passes, so there is always one kernel.
@skipIfXpu
@config.patch(shape_padding=True)
@config.patch(force_shape_pad=True)
@parametrize("sizes", ((250, 245, 128), (250, 256, 128), (256, 128, 62)))
def test_prologue_masked_load(self, sizes):
M, K, N = sizes
def foo(x, y):
return x @ y
x = torch.rand([250, 245], device=GPU_TYPE)
y = torch.rand([245, 128], device=GPU_TYPE)
# we should not attempt prologue fusion if it turns an aligned load
# into an unaligned load
out, code = run_and_get_code(torch.compile(foo), x, y)
self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
self.check_code(code[0], num_kernels=3, num_allocs=3, num_deallocs=4)
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu
# Set env to make it work in CI.
if HAS_GPU and HAS_CPU and is_big_gpu():
run_tests()