pytorch/test/inductor/test_native_matmul.py
xinan.lin 0918bf321c [xpu][test] Reuse native_mm and mix_order_reduction for Intel GPU. (#166384)
This PR reuses native_mm and mix_order_reduction for Intel GPU and enables the corresponding tests.
Fixes #165370

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166384
Approved by: https://github.com/jansel
2025-10-30 03:38:35 +00:00


# Owner(s): ["module: inductor"]
from typing import Callable

import torch
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_triton_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

aten = torch.ops.aten
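
# With triton.native_matmul enabled, Inductor generates the matmul itself
# inside its Triton kernels via tl.dot rather than calling out to a separate
# mm kernel; the tests below FileCheck the generated code for those tl.dot calls.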
@inductor_config.patch({"triton.native_matmul": True})
class TestTritonDotReduction(TestCase):
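
    # Run f both eagerly and through torch.compile and compare the outputs.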
    def _check_equal(
        self,
        f: Callable,
        example_inputs: tuple[torch.Tensor],
    ):
        compiled = torch.compile(f)
        actual = compiled(*example_inputs)
        expect = f(*example_inputs)
        self.assertTrue(same(expect, actual))
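
    # Compile f and FileCheck the generated Triton code: expect
    # `kernel_count` @triton.jit kernels and `dot_count` tl.dot calls.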
    def _check_code(
        self,
        f: Callable,
        example_inputs: tuple[torch.Tensor],
        kernel_count: int,
        dot_count: int,
    ):
        f = torch.compile(f)
        code = run_and_get_triton_code(f, *example_inputs)
        FileCheck().check_regex(r"triton.*mm.*\.run\(").run(code)
        FileCheck().check_count(
            "@triton.jit",
            kernel_count,
        ).check_count(
            "tl.dot",
            dot_count,
        ).run(code)
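
    # Plain 2D matmul: a single generated kernel with one tl.dot.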
    def test_matmul(self):
        def f(x, y):
            z = x @ y
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)
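
    # Matmul whose lhs is a 1D tensor broadcast to 2D via expand.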
    def test_mm_1d_expand(self):
        def f(x, y, M, K):
            z = x[:, None].expand(M, K) @ y
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((M,), (1,), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y, M, K))
        self._check_code(f, (x, y, M, K), 1, 1)
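
    # Matmul whose lhs is a size-1, stride-0 tensor expanded to (M, K).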
    def test_mm_2_expand(self):
        def f(x, y, M, K):
            z = x[:, None].expand(M, K) @ y
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((1,), (0,), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y, M, K))
        self._check_code(f, (x, y, M, K), 1, 1)
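
    # Mixed dtypes: the fp32 -> fp16 cast of y is handled in the same
    # kernel as the matmul (kernel_count stays at 1).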
    def test_matmul_fp16(self):
        def f(x, y):
            z = x @ y.to(x.dtype)
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), dtype=torch.float16, device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), dtype=torch.float32, device=GPU_TYPE)
        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)
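
    # Sizes that are not multiples of the block size exercise the reduction
    # masking: out-of-bounds elements must be zeroed out, not left undefined.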
    def test_reduction_mask_zeroout(self):
        def f(x, y):
            return (x + 1) @ (y - 2)

        M, K, N = 62, 62, 62
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)
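
    # Three matmuls plus the adds fused into one kernel with three tl.dot calls.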
    @skipIfXpu(
        msg="Intel triton issue: https://github.com/intel/intel-xpu-backend-for-triton/issues/5394"
    )
    def test_3mm_add(self):
        def f(x, y, z, w, r, t):
            return x @ y + z @ w + r @ t

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        w = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        z = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        r = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        t = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y, z, w, r, t))
        self._check_code(f, (x, y, z, w, r, t), 1, 3)
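
    # Gather (advanced indexing), matmul, and a pointwise epilogue all
    # land in a single generated kernel.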
    def test_mm_complex(self):
        def f(x, y, z, w):
            return x[z] @ y + w + 3

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        z = torch.randint(M, (M, K), dtype=torch.long, device=GPU_TYPE)
        w = rand_strided((M, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y, z, w))
        self._check_code(f, (x, y, z, w), 1, 1)
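
    # Batched matmul (torch.bmm) also lowers to a single tl.dot kernel.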
    def test_batchmatmul(self):
        def f(x, y):
            z = torch.bmm(x, y)
            return z

        B, M, K, N = 256, 128, 128, 128
        x = rand_strided((B, M, K), (M * K, K, 1), device=GPU_TYPE)
        y = rand_strided((B, K, N), (K * N, N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)


if HAS_GPU:
    torch.set_default_device(GPU_TYPE)

if __name__ == "__main__":
    if HAS_GPU:
        run_tests()