# Owner(s): ["module: inductor"]
import json
import os
import unittest
from typing import Callable, List, Optional

import torch
from torch import multiprocessing as mp, nn
from torch._dynamo import reset
from torch._dynamo.exc import BackendCompilerFailed
from torch._dynamo.testing import rand_strided, reset_rng_state
from torch._inductor import config
from torch._inductor.autotune_process import (
    BenchmarkRequest,
    CUDA_VISIBLE_DEVICES,
    TuningProcessPool,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import (
    AlgorithmSelectorCache,
    TritonTemplateCaller,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


torch.set_float32_matmul_precision("high")
if HAS_CUDA:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")

_CUTLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/")


def _get_path_without_sccache() -> str:
    """
    Get the PATH environment variable without sccache.
    """
    path_envs = os.environ.get("PATH", "").split(":")
    path_envs = [env for env in path_envs if "/opt/cache/bin" not in env]
    return ":".join(path_envs)


def benchmark_choice(choice, args, out, expected_out, timings):
    result = choice.benchmark(*args, out=out)
    if expected_out is not None:
        torch.testing.assert_close(out, expected_out)

    timings.copy_(torch.tensor(result))


class FailChoiceCaller(ChoiceCaller):
    def benchmark(self, *args, out):
        raise RuntimeError("This choice caller will always throw")


@instantiate_parametrized_tests
class TestMaxAutotune(TestCase):
    def _create_buffer(self, name, shape):
        return Buffer(name, FixedLayout(torch.device("cuda:0"), torch.float32, shape))

    def test_benchmark_choice_in_subproc(self):
        gm = make_fx(
            lambda: torch.zeros(2, 3)
        )()  # a dummy graph to construct the GraphLowering
        graph = GraphLowering(gm)

        # the graph handler is needed to create the benchmark example values below
        with V.set_graph_handler(graph):
            buf1 = self._create_buffer("mat1", (2, 3))
            buf2 = self._create_buffer("mat2", (3, 2))
            buf3 = self._create_buffer("mat3", (2, 3))
            buf4 = self._create_buffer("mat4", (3, 2))

            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))

            mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
            mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
            mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
            mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)

            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            # expected_out = (mat1 @ mat2) + (mat3 @ mat4)
            expected_out = None

            choice = aten_mm_plus_mm.bind((buf1, buf2, buf3, buf4), layout)
            # use a tensor since the mutation to a python list in a sub process
            # is not synced back to the parent process
            timings = torch.zeros(3, dtype=torch.float32)
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
            self.assertEqual(0, child.exitcode)
            print(f"timings is {timings}, out {out}, expected_out {expected_out}")

    def test_benchmark_choice_fail_in_subproc(self):
        gm = make_fx(
            lambda: torch.zeros(2, 3)
        )()  # a dummy graph to construct the GraphLowering
        graph = GraphLowering(gm)

        # the graph handler is needed to create the benchmark example values below
        with V.set_graph_handler(graph):
            buf1 = self._create_buffer("mat1", (2, 3))
            buf2 = self._create_buffer("mat2", (3, 2))
            buf3 = self._create_buffer("mat3", (2, 3))
            buf4 = self._create_buffer("mat4", (3, 2))

            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))

            mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
            mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
            mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
            mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)

            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            expected_out = (mat1 @ mat2) + (mat3 @ mat4)

            choice = FailChoiceCaller("fail_choice_caller", [], None)

            # use a tensor since python list is not synced back
            timings = torch.zeros(3, dtype=torch.float32)
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
            self.assertNotEqual(0, child.exitcode)

@parametrize("autotune_in_subproc", (True, False))
|
|
@parametrize("autotune_multi_device", (True, False))
|
|
def test_max_autotune_mm_plus_mm(self, autotune_in_subproc, autotune_multi_device):
|
|
"""
|
|
This crash previously due to a triton issue: https://github.com/openai/triton/issues/1298 .
|
|
With autotuning in subprocess, we don't crash anymore.
|
|
"""
|
|
m, n, k = 2048, 1536, 64
|
|
|
|
def mm_plus_mm(a, b, c, d):
|
|
return a @ b + c @ d
|
|
|
|
a = torch.randn(m, k).cuda()
|
|
b = torch.randn(k, n).cuda()
|
|
c = torch.randn(m, k).cuda()
|
|
d = torch.randn(k, n).cuda()
|
|
|
|
with config.patch(
|
|
{
|
|
"max_autotune": True,
|
|
"autotune_in_subproc": autotune_in_subproc,
|
|
"autotune_multi_device": autotune_multi_device,
|
|
}
|
|
):
|
|
torch.compile(mm_plus_mm)(a, b, c, d)
|
|
|
|
@parametrize("dynamic", (False, True))
|
|
def test_max_autotune_mm_plus_mm_zero_size_input(self, dynamic):
|
|
"""
|
|
Make sure autotuning mm_plus_mm with zero-size input works without crashes.
|
|
"""
|
|
m, n, k = 0, 1536, 64
|
|
|
|
def mm_plus_mm(a, b, c, d):
|
|
return a @ b + c @ d
|
|
|
|
a = torch.randn(m, k).cuda()
|
|
b = torch.randn(k, n).cuda()
|
|
c = torch.randn(m, k).cuda()
|
|
d = torch.randn(k, n).cuda()
|
|
|
|
with config.patch({"max_autotune": True}):
|
|
torch.compile(mm_plus_mm, dynamic=dynamic)(a, b, c, d)
|
|
|
|
@parametrize("dynamic", (False, True))
|
|
def test_max_autotune_regular_mm(self, dynamic: bool):
|
|
"""
|
|
Make sure autotuning mm in sub processes work without crashes.
|
|
"""
|
|
|
|
def mm(a, b):
|
|
a = torch.sin(a)
|
|
return a @ b
|
|
|
|
a = torch.randn(100, 10).cuda()
|
|
b = torch.randn(10, 100).cuda()
|
|
|
|
with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
|
|
torch.compile(mm, dynamic=dynamic)(a, b)
|
|
|
|
@parametrize("dynamic", (False, True))
|
|
def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool):
|
|
"""
|
|
Make sure autotuning mm with zero-size input works without crashes.
|
|
"""
|
|
|
|
def mm(a, b):
|
|
a = torch.sin(a)
|
|
return a @ b
|
|
|
|
a = torch.randn(0, 10).cuda()
|
|
b = torch.randn(10, 100).cuda()
|
|
|
|
with config.patch({"max_autotune": True}):
|
|
torch.compile(mm, dynamic=dynamic)(a, b)
|
|
|
|
    @skipIfRocm
    @parametrize("dynamic", (False, True))
    def test_max_autotune_remote_caching(self, dynamic: bool):
        from unittest.mock import patch

        def mm(a, b):
            a = torch.sin(a)
            return a @ b

        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        def f(x, y):
            return Model()(x, y)

        x = torch.randn(100, 100).cuda()
        y = torch.randn(100, 100).cuda()

        cache = {}
        num_get = 0
        num_put = 0

        class MyCache:
            def __init__(self, key, is_autotune=False):
                pass

            def get(self, filename):
                nonlocal cache
                nonlocal num_get
                if filename not in cache:
                    return None
                ret = json.loads(cache[filename])
                num_get += 1
                return ret

            def put(self, filename, data):
                nonlocal cache
                nonlocal num_put
                cache[filename] = json.dumps(data)
                num_put += 1

        cache_module = (
            "triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend"
            if config.is_fbcode()
            else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
        )

        with config.patch(
            {
                "autotune_local_cache": False,
                "autotune_remote_cache": True,
            }
        ), patch.dict(os.environ), patch(cache_module, MyCache, create=True):
            os.environ.pop("TRITON_CACHE_MANAGER", None)
            with config.patch({"max_autotune": True}):
                for _ in range(4):
                    with fresh_inductor_cache():
                        torch.compile(mm, dynamic=dynamic)(a, b)
                    reset()
                self.assertEqual(num_get, 3)
                self.assertEqual(num_put, 1)
                num_get = 0
                num_put = 0
                for _ in range(4):
                    with fresh_inductor_cache():
                        torch.compile(f, dynamic=dynamic)(x, y)
                    reset()
                self.assertEqual(num_get, 3)
                self.assertEqual(num_put, 1)

    @skipIfRocm
    def test_precompilation_threads(self):
        import threading
        from typing import Any, Dict
        from unittest.mock import Mock, patch

        class FakeChoiceCaller(ChoiceCaller):
            def __init__(self):
                super().__init__("none", [], Mock())
                self.thread_id = None

            def precompile(self):
                self.thread_id = threading.get_ident()

            def call_name(self) -> str:
                return None

            def to_callable(self):
                return None

            def hash_key(self) -> str:
                return str(hash(self))

            def output_node(self) -> "TensorBox":  # noqa: F821
                return None

        fake_choices = [FakeChoiceCaller() for i in range(10)]
        fake_lookup_result = dict.fromkeys(fake_choices, 0.123)

        def no_lookup(
            choices: List[ChoiceCaller],
            op: str,
            inputs: str,
            benchmark: Callable[[Any], Dict[ChoiceCaller, float]],
        ) -> Dict[ChoiceCaller, float]:
            if benchmark is not None:
                return benchmark(choices)

        asc = AlgorithmSelectorCache()

        def fake_benchmark_fn(*args, **kwargs):
            return fake_lookup_result

        main_thread_id = threading.get_ident()
        mock_debug_handler = Mock()
        old_debug_handler = V.debug
        try:
            V.set_debug_handler(mock_debug_handler)
            with patch.object(asc, "lookup", new=no_lookup):
                with patch.object(
                    asc, "make_benchmark_fn", return_value=fake_benchmark_fn
                ):
                    with config.patch(
                        {
                            "autotune_in_subproc": False,
                            "compile_threads": len(fake_choices),
                        }
                    ):
                        asc("test_call", fake_choices, [], Mock())
                        for fake_choice in fake_choices:
                            assert (
                                fake_choice.thread_id is not None
                            ), "Expected all ChoiceCallers' precompile methods to have been called"
                            assert (
                                fake_choice.thread_id != main_thread_id
                            ), "Expected all ChoiceCallers' precompile methods to have been called on a separate thread"
        finally:
            V.set_debug_handler(old_debug_handler)

@parametrize("dynamic", (False, True))
|
|
def test_max_autotune_addmm(self, dynamic=False):
|
|
"""
|
|
Make sure autotuning addmm in sub processes work without crashes.
|
|
"""
|
|
|
|
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
|
|
|
|
def addmm(x, a, b):
|
|
return torch.addmm(x, a, b)
|
|
|
|
x = torch.randn(100).cuda()
|
|
a = torch.randn(100, 10).cuda()
|
|
b = torch.randn(10, 100).cuda()
|
|
with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
|
|
Y_compiled = torch.compile(addmm, dynamic=dynamic)(x, a, b)
|
|
Y = addmm(x, a, b)
|
|
torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)
|
|
|
|
@parametrize("dynamic", (False, True))
|
|
def test_max_autotune_addmm_zero_size_input(self, dynamic):
|
|
"""
|
|
Make sure autotuning addmm with zero-size input works without crashes.
|
|
"""
|
|
|
|
def addmm(x, a, b):
|
|
return torch.addmm(x, a, b)
|
|
|
|
x = torch.randn(100).cuda()
|
|
a = torch.randn(0, 10).cuda()
|
|
b = torch.randn(10, 100).cuda()
|
|
with config.patch({"max_autotune": True}):
|
|
torch.compile(addmm, dynamic=dynamic)(x, a, b)
|
|
|
|
    @skipIfRocm
    def test_autotune_conv1x1(self):
        # Assuming input has 3 channels and we want to produce 16 channels as output
        conv1x1 = (
            torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1)
            .to(memory_format=torch.channels_last)
            .cuda()
        )

        # Example input tensor: batch size = 4, channels = 3, height = 32, width = 32
        # The memory format is set to `channels_last`
        input_tensor = (
            torch.randn(4, 3, 32, 32)
            .contiguous(memory_format=torch.channels_last)
            .cuda()
        )

        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            @torch.compile()
            def foo(mod, x):
                return mod(x)

            with torch.no_grad():
                out, code = run_and_get_code(foo, conv1x1, input_tensor)

            FileCheck().check_not("extern_kernels.convolution").run(code[0])
            self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)

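    # Once the first compile has populated the autotune caches, recompiling the
    # same graph after a dynamo reset should not trigger any new precompilation.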
    @skipIfRocm
    def test_filled_cache_precompile(self):
        def fn(a, b, c):
            a = (a @ b) @ c
            a, b, c = (t.to(torch.float16) for t in [a, b, c])
            return (a @ b) @ c

        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
        inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]
        from torch._dynamo.utils import counters

        self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)

        torch._dynamo.reset()
        counters.clear()

        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
        self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)

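    # With search_autotune_cache enabled, compilation should only consult the
    # autotune cache and never kick off kernel precompilation.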
    @skipIfRocm
    @fresh_inductor_cache()
    @config.patch(search_autotune_cache=True)
    def test_search_autotune_cache(self):
        def fn(a, b, c):
            a = (a @ b) @ c
            a, b, c = (t.to(torch.float16) for t in [a, b, c])
            return (a @ b) @ c

        fn_c = torch.compile()(fn)
        inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]
        from torch._dynamo.utils import counters

        self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)
        self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)

    @skipIfRocm
    @fresh_inductor_cache()
    @config.patch(max_autotune=True, max_fusion_size=2)
    def test_jit_fusion_matches_aot_fusion(self):
        # In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due
        # to proximity; we want to make sure the AOT-compile pass does the same.
        # AOT could do fuse(buf2, buf4) instead if buf3 was pushed to the end
        # of the V.graph.buffers list, because fuse(buf2, buf4) would have a
        # better proximity score than fuse(buf1, buf2). This scenario is possible
        # since finalizing MultiTemplateBuffers needs to replace buffers.
        def fn(x, number):
            buf0 = x + x
            buf1 = number.item()
            buf2 = x * x
            buf3 = x @ x  # MultiTemplateBuffer
            buf4 = x**2
            return buf0, buf1, buf2, buf3, buf4

        inputs = (torch.rand([256, 256], device="cuda"), torch.tensor(3, device="cuda"))
        torch._export.aot_compile(fn, args=inputs)

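    # fn below contains two distinct GEMM flavors (fp32 and fp16 matmuls on the
    # same shape), so exactly two precompilation events are expected.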
    @config.patch(autotune_local_cache=False, autotune_remote_cache=False)
    @skipIfRocm
    def test_precompilations(self):
        def fn(a, b, c):
            a = (a @ b) @ c
            a, b, c = (t.to(torch.float16) for t in [a, b, c])
            return (a @ b) @ c

        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
        inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]

        self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)

        from torch._dynamo.utils import counters

        self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 2)

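    # Autotuned addmm (Triton backend only) feeding into a cat should still
    # match the eager result.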
    def test_cat_addmm(self):
        def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
            return torch.cat(
                [
                    torch.addmm(a, b, c),
                    torch.addmm(b, c, a),
                ],
                1,
            )

        args = [
            torch.randn(4, 4, device="cuda"),
            torch.randn(4, 4, device="cuda"),
            torch.randn(4, 4, device="cuda"),
        ]
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "Triton",
            }
        ):
            expected = fn(*args)
            actual = torch.compile(fn)(*args)
            torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)

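    # A Triton GEMM template with relu/bias/mul epilogues, compiled with dynamic
    # shapes and exercised at two different M sizes.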
    def test_triton_template_with_epilogues_and_dynamic_shape(self):
        def fn(
            x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor, mul: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.nn.functional.relu(
                    torch.matmul(torch.transpose(x, 0, 1), torch.transpose(w, 0, 1))
                    + bias
                )
                * mul
            )

        M0 = 5
        M1 = 8
        K = 4
        N = 3
        w = torch.rand(N, K).cuda().half()
        b = torch.rand(N).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": "Triton",
            }
        ):
            compiled_fn = torch.compile(
                fn, fullgraph=True, dynamic=True, mode="max-autotune-no-cudagraphs"
            )

            x0 = torch.rand(K, M0).cuda().half()
            mul0 = torch.rand(M0, N).cuda().half()
            y0 = compiled_fn(x0, w, b, mul0)
            y0_expected = fn(x0, w, b, mul0)
            torch.testing.assert_close(y0, y0_expected)

            x1 = torch.rand(K, M1).cuda().half()
            mul1 = torch.rand(M1, N).cuda().half()
            y1 = compiled_fn(x1, w, b, mul1)
            y1_expected = fn(x1, w, b, mul1)
            torch.testing.assert_close(y1, y1_expected)

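    # Dropout after an autotuned matmul: with fallback_random enabled, the
    # compiled forward/backward should reproduce the eager gradients.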
    @config.patch(
        benchmark_kernel=True,
        fallback_random=True,
        max_autotune_gemm=True,
    )
    @parametrize("device", ("cpu", "cuda"))
    def test_matmul_dropout(self, device):
        def fwd(a, b):
            x = a @ b
            x = torch.nn.functional.dropout(x, 0.1)
            return x

        def fn(a, b):
            x = fwd(a, b).sum()
            x.backward()
            return a.grad

        N = 128
        a = torch.randn(N, N, device=device, requires_grad=True)
        b = torch.randn(N, N, device=device)

        opt_fn = torch.compile(fn)
        reset_rng_state()
        ref = fn(a, b)
        reset_rng_state()
        act = opt_fn(a, b)

        if N <= 8:
            print(f"ref\n{ref}\nact\n{act}")
        torch.testing.assert_close(ref, act, atol=1e-1, rtol=1e-1)

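    # Autotuning should respect the device of the inputs when they live on a
    # non-default CUDA device (cuda:1 here).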
    @config.patch(
        max_autotune_gemm=True,
    )
    @unittest.skipIf(
        torch.cuda.device_count() < 2, "Need at least 2 devices for this test"
    )
    def test_autotune_device_guard(self):
        x = torch.randn(1024, 1024, device="cuda:1")
        y = torch.randn(1024, 1024, device="cuda:1")

        def f(x, y):
            return x @ y

        with fresh_inductor_cache():
            act = torch.compile(f)(x, y)
        ref = f(x, y)
        self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3))

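    # Convolution with a zero-size batch should compile and match eager under
    # max-autotune; reused below with a 1x1 kernel.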
    @config.patch(max_autotune=True)
    def test_empty_conv_input(self, kernel_size=3):
        x = torch.randn(0, 256, 14, 14, device="cuda")
        weight = torch.randn(256, 256, kernel_size, kernel_size, device="cuda")

        def f(x, weight):
            return torch.convolution(
                x,
                weight,
                bias=None,
                stride=[1, 1],
                padding=[0, 0],
                dilation=[1, 1],
                transposed=False,
                output_padding=[0, 0],
                groups=1,
            )

        opt_f = torch.compile(f)
        ref = f(x, weight)
        act = opt_f(x, weight)
        self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3))

    @config.patch(max_autotune=True)
    def test_empty_conv_input_with_1x1_kernel(self):
        self.test_empty_conv_input(kernel_size=1)

    @config.patch(max_autotune=True)
    def test_conv1x1_with_free_symbols(self):
        """
        Make sure there is no exception due to free symbols.
        """
        conv = nn.Conv2d(
            3, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False
        ).to(device="cuda")

        @torch.compile
        def f(x, y, z):
            h = y.nonzero().size(0)
            w = z.nonzero().size(0)
            x = x[:, :, :h, :w]
            x = conv(x)
            return x

        x = torch.randn(4, 3, 224, 224).to(
            memory_format=torch.channels_last, device="cuda"
        )
        for _ in range(2):
            y = torch.randint(0, 10, (224,)).to(device="cuda")
            z = torch.randint(0, 10, (224,)).to(device="cuda")
            f(x, y, z)

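    # Basic conv3d correctness check under max-autotune.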
    def test_conv3d(self):
        fn = torch.nn.functional.conv3d
        image = torch.randn([1, 3, 8, 16, 32])
        filt = torch.randn([3, 3, 7, 7, 7])

        with config.patch({"max_autotune": True}):
            expected = fn(image, filt)
            actual = torch.compile(fn)(image, filt)
            torch.testing.assert_close(actual, expected, atol=6e-5, rtol=0.001)

    def test_non_contiguous_input_mm(self):
        """
        Make sure the triton template can work with non-contiguous inputs without crashing.
        Check https://github.com/pytorch/pytorch/issues/125437 for more details.
        """
        x = rand_strided(
            (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda"
        )
        y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda")

        @torch.compile(mode="max-autotune")
        def f(x, y):
            return x @ y

        ref = x @ y
        act = f(x, y)
        self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2))

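    # Same non-contiguous layouts as test_non_contiguous_input_mm, exercised
    # through addmm.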
    def test_non_contiguous_input_addmm(self):
        b = torch.randn((768), dtype=torch.bfloat16, device="cuda")
        x = rand_strided(
            (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda"
        )
        y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda")

        @torch.compile(mode="max-autotune")
        def f(x, y):
            return torch.addmm(b, x, y)

        ref = torch.addmm(b, x, y)
        act = f(x, y)
        self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2))

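    # Same non-contiguous layouts, exercised through bmm with a batch dimension
    # of size 1.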
    def test_non_contiguous_input_bmm(self):
        x = rand_strided(
            (1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device="cuda"
        )
        y = rand_strided(
            (1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device="cuda"
        )

        @torch.compile(mode="max-autotune")
        def f(x, y):
            return torch.bmm(x, y)

        ref = torch.bmm(x, y)
        act = f(x, y)
        self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2))

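    # Same non-contiguous layouts, exercised through the mm_plus_mm lowering.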
    def test_non_contiguous_input_mm_plus_mm(self):
        x1 = rand_strided((50257, 32768), (1, 50304), device="cuda")
        y1 = rand_strided((32768, 768), (768, 1), device="cuda")

        x2 = rand_strided((50257, 32768), (1, 50304), device="cuda")
        y2 = rand_strided((32768, 768), (768, 1), device="cuda")

        @torch.compile(mode="max-autotune")
        def f(x1, y1, x2, y2):
            return x1 @ y1 + x2 @ y2

        ref = x1 @ y1 + x2 @ y2
        act = f(x1, y1, x2, y2)
        self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2))

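    # With no GEMM backends enabled and no ATen fallback, lowering a matmul
    # should surface a NoValidChoicesError.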
    @config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="",
        autotune_fallback_to_aten=False,
    )
    def test_no_valid_choices(self):
        a = torch.zeros([2, 2], device="cuda")
        b = torch.zeros([2, 2], device="cuda")
        with self.assertRaises(BackendCompilerFailed) as context:
            torch.compile(lambda a, b: a.matmul(b))(a, b)
        self.assertIn("NoValidChoicesError", str(context.exception))

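    # If every benchmarked choice reports an infinite timing, algorithm
    # selection should also fail with NoValidChoicesError.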
@parametrize("multi_template", (True, False))
|
|
@config.patch(
|
|
max_autotune=True,
|
|
max_autotune_gemm_backends="TRITON",
|
|
autotune_fallback_to_aten=False,
|
|
)
|
|
def test_inf_timing(self, multi_template):
|
|
from unittest.mock import patch
|
|
|
|
lookup = AlgorithmSelectorCache.lookup
|
|
|
|
def mock_lookup(self, *args, **kwargs):
|
|
timings = lookup(self, *args, **kwargs)
|
|
return {choice: float("inf") for choice in timings.keys()}
|
|
|
|
a = torch.zeros([16, 16], device="cuda")
|
|
b = torch.zeros([16, 16], device="cuda")
|
|
with patch.object(AlgorithmSelectorCache, "lookup", mock_lookup), config.patch(
|
|
benchmark_epilogue_fusion=multi_template
|
|
):
|
|
with self.assertRaises(BackendCompilerFailed) as context:
|
|
torch.compile(lambda a, b: a.matmul(b))(a, b)
|
|
self.assertIn("NoValidChoicesError", str(context.exception))
|
|
|
|
|
|
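# Fake benchmark request used by the TuningProcessPool tests below: it returns a
# fixed value and validates CUDA_VISIBLE_DEVICES from inside the benchmark
# subprocess.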
class TestBenchmarkRequest(BenchmarkRequest):
    def __init__(
        self, value: float, multi_device: bool, parent_visible_devices: Optional[str]
    ) -> None:
        self.value = value
        self.multi_device = multi_device
        self.parent_visible_devices = parent_visible_devices

    def benchmark(
        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
    ) -> float:
        # Verify that the visible devices env var is set correctly. If multi-device
        # auto-tuning is disabled, the visible devices should be unchanged from
        # the parent process. If multi-device auto-tuning is enabled, the visible
        # devices should be a _single_ valid device number. Note that we can't perform
        # this validation directly from the test body because benchmarks execute in a
        # separate process. If the check fails, however, the test will detect the
        # failure by virtue of not receiving the expected result back.
        visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES)
        if not self.multi_device:
            assert visible_devices == self.parent_visible_devices
        else:
            valid_devices = self.parent_visible_devices.split(",")
            assert visible_devices in valid_devices

        return self.value


class TestTritonTemplateCaller(TritonTemplateCaller):
    def __init__(self, bmreq: TestBenchmarkRequest):
        self.bmreq = bmreq

    def __str__(self) -> str:
        return "test"


class TestTuningProcess(TestCase):
    def test_tuning_pool_crash(self):
        # Use only one device/subprocess so we test the process restarts
        # and is usable after a "crash".
        with config.patch({"autotune_multi_device": False}):
            tuning_pool = TuningProcessPool()
            tuning_pool.initialize()

            # First force the tuning process to "crash" by setting a bogus
            # string for the expected visible devices.
            bmreq = TestBenchmarkRequest(3.14, False, "invalid")
            choice = TestTritonTemplateCaller(bmreq)

            timings = tuning_pool.benchmark([choice])
            self.assertTrue(choice in timings)
            self.assertEqual(timings[choice], float("inf"))

            # Then send another request and make sure the sub-process
            # has restarted and is operational. 'valid_devices' expected
            # to be None because autotune_multi_device is off.
            choice.bmreq.parent_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES)

            timings = tuning_pool.benchmark([choice])
            self.assertTrue(choice in timings)
            self.assertEqual(timings[choice], bmreq.value)

            tuning_pool.terminate()

    def test_tuning_pool_multiple_devices(self):
        with config.patch({"autotune_multi_device": True}):
            # Adapt the test to the available devices (and whether CUDA_VISIBLE_DEVICES
            # is already set in the environment); use a subset of the available devices
            # to ensure only the subset are visible to the sub-processes.
            if CUDA_VISIBLE_DEVICES in os.environ:
                visible_devices = os.environ[CUDA_VISIBLE_DEVICES].split(",")
            else:
                visible_devices = [str(d) for d in range(torch.cuda.device_count())]

            parent_visible_devices = ",".join(visible_devices[-2:])
            os.environ[CUDA_VISIBLE_DEVICES] = parent_visible_devices

            tuning_pool = TuningProcessPool()
            tuning_pool.initialize()

            choice1 = TestTritonTemplateCaller(
                TestBenchmarkRequest(3.14, True, parent_visible_devices),
            )
            choice2 = TestTritonTemplateCaller(
                TestBenchmarkRequest(2.718, True, parent_visible_devices),
            )

            timings = tuning_pool.benchmark([choice1, choice2])
            self.assertEqual(timings[choice1], choice1.bmreq.value)
            self.assertEqual(timings[choice2], choice2.bmreq.value)

            tuning_pool.terminate()


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    # Set env to make it work in CI.
    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
        run_tests()