pytorch/test/inductor/test_select_algorithm.py
Jason Ansel d6dd2e97da [inductor] Rewrite Triton templates + epilogue fusion (retry) (#91105)
https://github.com/pytorch/pytorch/pull/90738 seems a bit borked. ghimport fails on it, and I unlinked it from the Phabricator diff, but it still won't land.  This is an exact copy that PR without using ghstack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91105
Approved by: https://github.com/ngimel
2022-12-20 02:38:23 +00:00

147 lines
4.5 KiB
Python

# Owner(s): ["module: inductor"]
import functools
import logging
from unittest.mock import patch
import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
import torch.nn.functional as F
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
torch.backends.cuda.matmul.allow_tf32 = False
def patches(fn):
def skip_cache(self, key, generate):
return generate()
for patcher in [
patch.object(dynamo_config, "log_level", logging.INFO),
patch.object(dynamo_config, "verbose", True),
patch.object(inductor_config, "debug", True),
patch.object(inductor_config, "max_autotune", True),
patch.object(inductor_config, "epilogue_fusion", True),
patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
]:
fn = patcher(fn)
@functools.wraps(fn)
def wrapped(*args, **kwargs):
counters.clear()
torch.manual_seed(12345)
assert (
not torch.backends.cuda.matmul.allow_tf32
), "correctness testing is allergic to tf32"
return fn(*args, **kwargs)
return wrapped
class TestSelectAlgorithm(TestCase):
@patches
def test_linear_relu(self):
@torch.compile
def foo(input, weight, bias):
return F.relu(F.linear(input, weight, bias))
foo(
torch.randn(64, 32, device="cuda"),
torch.randn(16, 32, device="cuda"),
torch.randn(16, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
# It would be nice to assert this got fused into a single kernel, but that
# only happens if we select a triton template (and not aten).
@patches
def test_addmm(self):
@torch.compile
def foo(input, weight, bias):
return torch.addmm(bias, input, weight)
foo(
torch.randn(20, 33, device="cuda"),
torch.randn(33, 16, device="cuda"),
torch.randn(20, 16, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test_mm(self):
@torch.compile
def foo(a, b):
return torch.mm(a, b)
foo(
torch.randn(8, 32, device="cuda"),
torch.randn(32, 8, device="cuda"),
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test_mm_skip(self):
@torch.compile
def foo(a, b):
return torch.mm(a, b)
foo(
torch.randn(8, 32, device="cuda", dtype=torch.float64),
torch.randn(32, 8, device="cuda", dtype=torch.float64),
)
# float64 not supported by tl.dot()
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)
@patches
def test_bmm(self):
@torch.compile
def foo(a, b):
return torch.bmm(a, b)
foo(
torch.randn(2, 8, 32, device="cuda"),
torch.randn(2, 32, 8, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test_mm_not_even_k(self):
@torch.compile
def foo(a, b):
return torch.mm(a, b)
foo(
torch.randn(11, 22, device="cuda"),
torch.randn(22, 33, device="cuda"),
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test_baddbmm(self):
@torch.compile
def foo(a, b, c):
return torch.baddbmm(c, a, b)
foo(
torch.randn(2, 8, 32, device="cuda"),
torch.randn(2, 32, 8, device="cuda"),
torch.randn(2, 1, 8, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu
if IS_LINUX and HAS_CUDA and is_big_gpu(0):
run_tests()