pytorch/test/inductor/test_select_algorithm.py
Jason Ansel 8c09a005c5 [inductor] Pattern matching engine (copy) (#93291)
This is an exact duplicate of https://github.com/pytorch/pytorch/pull/90739

The fbcode workflow for landing that diff seems buggy.  The github-export-checks task is failing with credentials errors.  Plan to try to land it using GH1.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93291
Approved by: https://github.com/desertfire
2023-01-31 04:51:00 +00:00


# Owner(s): ["module: inductor"]
import functools
import logging
from unittest.mock import patch

import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
import torch.nn.functional as F
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA

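# TF32 perturbs matmul numerics, so keep it off: the autotuner's correctness
# checks below want exact float32 results ("correctness testing is allergic to tf32").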
torch.backends.cuda.matmul.allow_tf32 = False


def patches(fn):
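    # Bypass the algorithm-selection cache so every test run re-triggers autotuning.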
    def skip_cache(self, key, generate):
        return generate()

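    # max_autotune forces algorithm selection to run; VERIFY makes it check each
    # candidate's output for correctness within these tolerances.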
    for patcher in [
        patch.object(dynamo_config, "log_level", logging.INFO),
        patch.object(dynamo_config, "verbose", True),
        patch.object(inductor_config, "debug", True),
        patch.object(inductor_config, "max_autotune", True),
        patch.object(inductor_config, "epilogue_fusion", True),
        patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
        patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
    ]:
        fn = patcher(fn)
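
    # Every wrapped test starts from cleared counters and a fixed RNG seed.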
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        counters.clear()
        torch.manual_seed(12345)
        assert (
            not torch.backends.cuda.matmul.allow_tf32
        ), "correctness testing is allergic to tf32"
        return fn(*args, **kwargs)

    return wrapped

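
# Each test compiles a small matmul-flavored op with max_autotune enabled and
# inspects the select_algorithm_autotune counter to see whether autotuning fired.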
class TestSelectAlgorithm(TestCase):
    @patches
    def test_linear_relu(self):
        @torch.compile
        def foo(input, weight, bias):
            return F.relu(F.linear(input, weight, bias))

        foo(
            torch.randn(64, 32, device="cuda"),
            torch.randn(16, 32, device="cuda"),
            torch.randn(16, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        # It would be nice to assert this got fused into a single kernel, but that
        # only happens if we select a triton template (and not aten).

    @patches
    def test_addmm(self):
        @torch.compile
        def foo(input, weight, bias):
            return torch.addmm(bias, input, weight)

        foo(
            torch.randn(20, 33, device="cuda"),
            torch.randn(33, 16, device="cuda"),
            torch.randn(20, 16, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device="cuda"),
            torch.randn(32, 8, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_skip(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device="cuda", dtype=torch.float64),
            torch.randn(32, 8, device="cuda", dtype=torch.float64),
        )
        # float64 not supported by tl.dot()
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

    @patches
    def test_bmm(self):
        @torch.compile
        def foo(a, b):
            return torch.bmm(a, b)

        foo(
            torch.randn(2, 8, 32, device="cuda"),
            torch.randn(2, 32, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_not_even_k(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

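        # K=22 (and the odd M/N) does not divide evenly into typical template
        # block sizes, exercising the masked/remainder handling.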
        foo(
            torch.randn(11, 22, device="cuda"),
            torch.randn(22, 33, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_baddbmm(self):
        @torch.compile
        def foo(a, b, c):
            return torch.baddbmm(c, a, b)

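        # the bias has a size-1 dim, so it broadcasts across the rows of each batch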
        foo(
            torch.randn(2, 8, 32, device="cuda"),
            torch.randn(2, 32, 8, device="cuda"),
            torch.randn(2, 1, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_plus_mm(self):
        @torch.compile
        def foo(a, b, c, d):
            return (a @ b) + (c @ d)

        foo(
            torch.randn(32, 32, device="cuda"),
            torch.randn(32, 32, device="cuda"),
            torch.randn(32, 32, device="cuda"),
            torch.randn(32, 32, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

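    # Only run on Linux with a CUDA GPU large enough for the Triton matmul templates.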
    if IS_LINUX and HAS_CUDA and is_big_gpu(0):
        run_tests()