# Owner(s): ["module: inductor"]
import functools
import logging
from unittest.mock import patch

import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
import torch.nn.functional as F
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA

# Keep TF32 off: autotune verification compares exact fp32 results.
torch.backends.cuda.matmul.allow_tf32 = False


def patches(fn):
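    """
    Decorator stacking the config patches shared by every test below: turn on
    max_autotune and epilogue fusion, enable numeric verification of each
    autotuned kernel (atol/rtol 1e-4), and bypass the algorithm-selection
    cache so every test re-runs autotuning. Wrapped tests also start from
    cleared counters and a fixed seed.
    """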
    def skip_cache(self, key, generate):
        return generate()

    for patcher in [
        patch.object(dynamo_config, "log_level", logging.INFO),
        patch.object(dynamo_config, "verbose", True),
        patch.object(inductor_config, "debug", True),
        patch.object(inductor_config, "max_autotune", True),
        patch.object(inductor_config, "epilogue_fusion", True),
        patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
        patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
    ]:
        fn = patcher(fn)

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        counters.clear()
        torch.manual_seed(12345)
        assert (
            not torch.backends.cuda.matmul.allow_tf32
        ), "correctness testing is allergic to tf32"
        return fn(*args, **kwargs)

    return wrapped


class TestSelectAlgorithm(TestCase):
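    """
    Each test compiles a small matmul-style op under max_autotune and asserts,
    via the "select_algorithm_autotune" counter, that algorithm selection ran
    exactly once (or not at all where the Triton templates don't apply).
    """
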
    @patches
    def test_linear_relu(self):
        @torch.compile
        def foo(input, weight, bias):
            return F.relu(F.linear(input, weight, bias))

        foo(
            torch.randn(64, 32, device="cuda"),
            torch.randn(16, 32, device="cuda"),
            torch.randn(16, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        # It would be nice to assert this got fused into a single kernel, but that
        # only happens if we select a triton template (and not aten).

    @patches
    def test_addmm(self):
        @torch.compile
        def foo(input, weight, bias):
            return torch.addmm(bias, input, weight)

        foo(
            torch.randn(20, 33, device="cuda"),
            torch.randn(33, 16, device="cuda"),
            torch.randn(20, 16, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device="cuda"),
            torch.randn(32, 8, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_skip(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device="cuda", dtype=torch.float64),
            torch.randn(32, 8, device="cuda", dtype=torch.float64),
        )
        # float64 not supported by tl.dot()
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

    @patches
    def test_bmm(self):
        @torch.compile
        def foo(a, b):
            return torch.bmm(a, b)

        foo(
            torch.randn(2, 8, 32, device="cuda"),
            torch.randn(2, 32, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_not_even_k(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        # Shapes that don't divide evenly into typical tile sizes ("not even k")
        foo(
            torch.randn(11, 22, device="cuda"),
            torch.randn(22, 33, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_baddbmm(self):
        @torch.compile
        def foo(a, b, c):
            return torch.baddbmm(c, a, b)

        foo(
            torch.randn(2, 8, 32, device="cuda"),
            torch.randn(2, 32, 8, device="cuda"),
            # bias has shape (2, 1, 8) and broadcasts across the rows
            torch.randn(2, 1, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)


if __name__ == "__main__":
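    # Autotune templates are only exercised on Linux CUDA machines with a
    # sufficiently large GPU (see is_big_gpu).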
    from torch._inductor.utils import is_big_gpu

    if IS_LINUX and HAS_CUDA and is_big_gpu(0):
        run_tests()