# Owner(s): ["module: inductor"] from typing import Callable import torch from torch._dynamo.testing import rand_strided from torch._dynamo.utils import same from torch._inductor import config as inductor_config from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_triton_code from torch.testing import FileCheck from torch.testing._internal.common_utils import skipIfXpu from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU aten = torch.ops.aten @inductor_config.patch({"triton.native_matmul": True}) class TestTritonDotReduction(TestCase): def _check_equal( self, f: Callable, example_inputs: tuple[torch.Tensor], ): compiled = torch.compile(f) actual = compiled(*example_inputs) expect = f(*example_inputs) self.assertTrue(same(expect, actual)) def _check_code( self, f: Callable, example_inputs: tuple[torch.Tensor], kernel_count: int, dot_count: int, ): f = torch.compile(f) code = run_and_get_triton_code(f, *example_inputs) FileCheck().check_regex(r"triton.*mm.*\.run\(").run(code) FileCheck().check_count( "@triton.jit", kernel_count, ).check_count( "tl.dot", dot_count, ).run(code) def test_matmul(self): def f(x, y): z = x @ y return z M, K, N = 128, 128, 128 x = rand_strided((M, K), (K, 1), device=GPU_TYPE) y = rand_strided((K, N), (N, 1), device=GPU_TYPE) self._check_equal(f, (x, y)) self._check_code(f, (x, y), 1, 1) def test_mm_1d_expand(self): def f(x, y, M, K): z = x[:, None].expand(M, K) @ y return z M, K, N = 128, 128, 128 x = rand_strided((M,), (1,), device=GPU_TYPE) y = rand_strided((K, N), (N, 1), device=GPU_TYPE) self._check_equal(f, (x, y, M, K)) self._check_code(f, (x, y, M, K), 1, 1) def test_mm_2_expand(self): def f(x, y, M, K): z = x[:, None].expand(M, K) @ y return z M, K, N = 128, 128, 128 x = rand_strided((1,), (0,), device=GPU_TYPE) y = rand_strided((K, N), (N, 1), device=GPU_TYPE) self._check_equal(f, (x, y, M, K)) self._check_code(f, (x, y, M, K), 1, 1) def test_matmul_fp16(self): def f(x, y): z = x @ y.to(x.dtype) return z M, K, N = 128, 128, 128 x = rand_strided((M, K), (K, 1), dtype=torch.float16, device=GPU_TYPE) y = rand_strided((K, N), (N, 1), dtype=torch.float32, device=GPU_TYPE) self._check_equal(f, (x, y)) self._check_code(f, (x, y), 1, 1) def test_reduction_mask_zeroout(self): def f(x, y): return (x + 1) @ (y - 2) M, K, N = 62, 62, 62 x = rand_strided((M, K), (K, 1), device=GPU_TYPE) y = rand_strided((K, N), (N, 1), device=GPU_TYPE) self._check_equal(f, (x, y)) self._check_code(f, (x, y), 1, 1) @skipIfXpu( msg="Intel triton issue: https://github.com/intel/intel-xpu-backend-for-triton/issues/5394" ) def test_3mm_add(self): def f(x, y, z, w, r, t): return x @ y + z @ w + r @ t M, K, N = 128, 128, 128 x = rand_strided((M, K), (K, 1), device=GPU_TYPE) y = rand_strided((K, N), (N, 1), device=GPU_TYPE) w = rand_strided((M, K), (K, 1), device=GPU_TYPE) z = rand_strided((K, N), (N, 1), device=GPU_TYPE) r = rand_strided((M, K), (K, 1), device=GPU_TYPE) t = rand_strided((K, N), (N, 1), device=GPU_TYPE) self._check_equal(f, (x, y, z, w, r, t)) self._check_code(f, (x, y, z, w, r, t), 1, 3) def test_mm_complex(self): def f(x, y, z, w): return x[z] @ y + w + 3 M, K, N = 128, 128, 128 x = rand_strided((M, K), (K, 1), device=GPU_TYPE) y = rand_strided((K, N), (N, 1), device=GPU_TYPE) z = torch.randint(M, (M, K), dtype=torch.long, device=GPU_TYPE) w = rand_strided((M, N), (N, 1), device=GPU_TYPE) self._check_equal(f, (x, y, z, w)) self._check_code(f, (x, y, z, w), 1, 1) def test_batchmatmul(self): def f(x, y): z = torch.bmm(x, y) return z B, M, K, N = 256, 128, 128, 128 x = rand_strided((B, M, K), (M * K, K, 1), device=GPU_TYPE) y = rand_strided((B, K, N), (K * N, N, 1), device=GPU_TYPE) self._check_equal(f, (x, y)) self._check_code(f, (x, y), 1, 1) if HAS_GPU: torch.set_default_device(GPU_TYPE) if __name__ == "__main__": if HAS_GPU: run_tests()