pytorch/test/inductor/test_benchmark_fusion.py
etaf 1b655a87ef [xpu][test] Enable more UTs for Intel GPU. (#166047)
This PR enables additional Inductor unit tests for Intel GPU. Due to the increased number of test cases, the number of runners has been extended from 8 to 12 to prevent CI timeouts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166047
Approved by: https://github.com/jansel

Co-authored-by: Deng, Daisy <daisy.deng@intel.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
2025-10-29 06:25:36 +00:00

356 lines
12 KiB
Python

# Owner(s): ["module: inductor"]
import math
import os
import sys
import torch
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.test_operators import realize
from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import slowTest
from torch.testing._internal.inductor_utils import (
get_func_call,
GPU_TYPE,
HAS_CPU,
HAS_GPU_AND_TRITON,
IS_BIG_GPU,
)
# Make the helper files in test/ importable: realpath() resolves this file
# (following symlinks), and the two dirname() calls step up to the parent of
# the directory that contains it (test/, for test/inductor/<this file>).
# Appending that directory to sys.path lets the
# `from inductor.test_torchinductor import ...` below resolve, so this must
# stay ahead of that import.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
import contextlib
import unittest
from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library
check_model,
check_model_gpu,
copy_tests,
skip_if_cpp_wrapper,
)
from torch._inductor import config
from torch._inductor.scheduler import Scheduler
class TestCase(InductorTestCase):
    """InductorTestCase whose tests run with kernel and fusion benchmarking on.

    The ``config.patch`` context is entered once when the class is set up and
    parked on a class-level ``contextlib.ExitStack``; ``tearDownClass`` closes
    that stack, which exits the patch context again.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        stack = cls._stack = contextlib.ExitStack()
        patched_options = {
            "benchmark_kernel": True,
            "benchmark_fusion": True,
        }
        stack.enter_context(config.patch(patched_options))

    @classmethod
    def tearDownClass(cls):
        # Leave the patched config first, then let the base class clean up.
        cls._stack.close()
        super().tearDownClass()
class BenchmarkFusionTestTemplate:
    """Device-agnostic tests for Inductor's benchmark-fusion path.

    This is a template, not a TestCase: its ``test_*`` methods are copied onto
    concrete test classes with ``copy_tests``. The host class must provide
    ``self.device`` and ``self.common`` (a model-checking helper such as
    ``check_model`` or ``check_model_gpu``).
    """

    def test_softmax(self):
        def f(x):
            return torch.nn.functional.softmax(x, dim=-1)

        self.common(f, (torch.rand(2, 8192),))

    @slowTest
    def test_resnet18(self):
        try:
            import torchvision
        except ImportError:
            self.skipTest("TorchVision not available")

        model = torchvision.models.resnet18()
        model.eval()
        batch_size = 16
        inputs = (torch.randn((batch_size, 3, 224, 224)),)
        self.common(model, inputs, atol=1e-2, rtol=1e-2)

    def test_register_spills(self):
        """
        The test can potentially trigger register spills
        """
        # ``import unittest`` does not import the ``unittest.mock`` submodule;
        # import it explicitly instead of relying on another module having
        # loaded it already.
        from unittest import mock

        old_benchmark_fn = Scheduler.benchmark_fused_nodes

        def new_benchmark_fn(scheduler, nodes):
            """
            We override Scheduler.benchmark_fused_nodes to return latency 1.0
            if there are no register spills. Without this, we may not be able
            to test the code path handling register spilling, because before
            registers start spilling, the related fusion may have already been
            skipped due to longer latency.
            """
            ms, path = old_benchmark_fn(scheduler, nodes)
            # Keep an infinite latency as-is; flatten every finite one to 1.0.
            if not math.isinf(ms):
                ms = 1.0
            return ms, path

        # Disable dynamic_scale_rblock to make it easier to trigger register
        # spilling.
        with (
            mock.patch.object(Scheduler, "benchmark_fused_nodes", new_benchmark_fn),
            config.patch("dynamic_scale_rblock", False),
        ):
            S = 512

            def f(*inputs):
                # Each input is scaled, shifted and reduced over its last dim;
                # every reduction is returned and also accumulated into `out`.
                inputs = list(inputs)
                outputs = []
                out = torch.zeros(S, device=self.device)
                for x in inputs:
                    x = x * 2
                    x = x + 1
                    x = x.sum(dim=-1)
                    outputs.append(x)
                    out = out + x
                return outputs, out

            # The number of inputs can be overridden via the NINP env var.
            N = int(os.environ.get("NINP", "30"))
            inputs = [torch.randn(S, 2560, device=self.device) for _ in range(N)]
            opt_f = torch.compile(f)
            opt_f(*inputs)

    def test_foreach_kernel(self):
        """
        Benchmark fusion should skip benchmarking kernels involving a foreach
        kernel for now. Without the skipping logic, `codegen_node_schedule`
        may fail.
        """
        a = torch.randn(1024, 256, device=self.device)
        b = torch.randn(1024, 512, device=self.device)

        def f(a, b):
            a, b = torch._foreach_abs([a, b])
            return a + 1, b + 2

        self.common(f, (a, b))

    @unittest.skipIf(
        not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
    )
    @config.patch(max_autotune_gemm_backends="TRITON")
    def test_avoid_register_spilling(self):
        """
        Compile a Linear followed by a chain of gelu/add ops with max-autotune
        twice: once with the class-level benchmark-fusion config, and once with
        benchmark_fusion and epilogue_fusion disabled. Check the shape of the
        generated wrapper code in both runs.
        """
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("GPU only")

        from torch.nn.functional import gelu

        def foo(m, inp):
            curr = m(inp)
            tmps = []
            for _ in range(4):
                curr = gelu(curr)
                for t in tmps:
                    curr = curr + t
                tmps.append(curr)
            return curr

        m = torch.nn.Linear(2048, 2048, bias=True).half().to(GPU_TYPE)
        inp = torch.rand([2048, 2048]).half().to(GPU_TYPE)

        with torch.no_grad():
            foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
            _, out_code = run_and_get_code(foo_c, m, inp)

            # Occasionally, CI will make this one kernel. Report that as a
            # skip instead of silently passing via an early return.
            if out_code[0].count("def triton_") != 2:
                self.skipTest("expected exactly 2 triton kernels in the output code")

            # should be multiple triton invocations
            FileCheck().check("async_compile.wait").check_count(
                ".run", 2, exactly=True
            ).run(out_code[0])

        with (
            config.patch({"benchmark_fusion": False, "epilogue_fusion": False}),
            torch.no_grad(),
        ):
            torch._dynamo.reset()
            foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
            _, out_code2 = run_and_get_code(foo_c, m, inp)

        for c in out_code[0], out_code2[0]:
            FileCheck().check("async_compile.wait").check("DeviceGuard").check_count(
                f"empty_strided_{GPU_TYPE}", 1, exactly=True
            ).check_regex("buf[0-9]* = buf[0-9]*; del buf[0-9]*").check("return").run(c)

    # NOTE(review): "tield" looks like a typo for "tiled"; the name is kept
    # because renaming a test changes its test ID.
    def test_tield_kernel_fusion(self):
        def f(x):
            # Add x to its transpose, force that result into its own buffer
            # with realize(), then add 1 on top of it.
            y = realize(x + x.t())
            return y + 1

        x = torch.randn(1024, 1024, device=self.device)
        self.common(f, (x,))
if HAS_GPU_AND_TRITON:
    # Everything in this section targets GPU_TYPE, which is used both as a
    # device string and as an attribute name on ``torch``.

    class BenchmarkFusionGpuTest(TestCase):
        """Runs BenchmarkFusionTestTemplate's tests on GPU_TYPE."""

        common = check_model_gpu
        device = GPU_TYPE

    copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionGpuTest, GPU_TYPE)

    class BenchmarkingTest(TestCase):
        @unittest.skipIf(
            getattr(torch, GPU_TYPE).device_count() < 2,
            "The test need at least 2 devices",
        )
        @skip_if_cpp_wrapper("This tests triton scheduling directly")
        def test_benchmark_on_non_zero_device(self):
            """
            Compile with device 0 as the current device but the input on
            device 1, and check that TritonScheduling.benchmark_codegened_module
            is reached and reports a positive latency each time.
            """
            # ``import unittest`` does not import the ``unittest.mock`` submodule.
            from unittest import mock

            hit_count = 0
            with getattr(torch, GPU_TYPE).device(f"{GPU_TYPE}:0"):

                @torch.compile
                def relu(x):
                    return realize(x.relu()) + x

                x = torch.randn(int(16e6), device=f"{GPU_TYPE}:1")

                orig_benchmark_codegened_module = (
                    TritonScheduling.benchmark_codegened_module
                )

                def benchmark_codegened_module(*args, **kwargs):
                    # Count calls and check each benchmark result is positive.
                    nonlocal hit_count
                    hit_count += 1
                    ms, path = orig_benchmark_codegened_module(*args, **kwargs)
                    self.assertGreater(ms, 0)
                    return ms, path

                with mock.patch.object(
                    TritonScheduling,
                    "benchmark_codegened_module",
                    benchmark_codegened_module,
                ):
                    relu(x)

                # The patched hook must have run at least once.
                self.assertGreater(hit_count, 0)

    class BenchmarkMultiTemplateFusionGpuTest(InductorTestCase):
        """GPU tests run with kernel, fusion and epilogue-fusion benchmarking on.

        The config patch is entered in ``setUpClass`` and undone in
        ``tearDownClass``. Every test is skipped unless ``is_big_gpu()``.
        """

        @classmethod
        def setUpClass(cls):
            super().setUpClass()
            cls._stack = contextlib.ExitStack()
            cls._stack.enter_context(
                config.patch(
                    {
                        "benchmark_kernel": True,
                        "benchmark_fusion": True,
                        "benchmark_epilogue_fusion": True,
                    }
                )
            )

        @classmethod
        def tearDownClass(cls):
            cls._stack.close()
            super().tearDownClass()

        def setUp(self):
            super().setUp()
            if not is_big_gpu():
                # skipTest raises unittest.SkipTest; nothing after it runs.
                self.skipTest("Need a big GPU to run max_autotune=True")

        def _equivalent_output_code_impl(self, size, first_dim=None, activation=True):
            """Compile a half-precision Linear (+ optional ReLU) twice and compare.

            The first compile uses the class config (benchmark_epilogue_fusion
            on). The second compile uses benchmark_epilogue_fusion=False. Both
            results must match.

            Args:
                size: in/out features of the Linear layer.
                first_dim: leading dim of the input; defaults to ``size``.
                activation: apply ReLU to the Linear output when True.

            Returns:
                ``(code, code2)``: the generated code captured by
                ``run_and_get_code`` for the first and second compile.
            """

            def foo(m, inp):
                a = m(inp)
                if activation:
                    return torch.nn.functional.relu(a)
                return a

            foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
            first_dim = first_dim if first_dim is not None else size

            m = torch.nn.Linear(size, size, bias=True).half().to(GPU_TYPE)
            inp = torch.rand([first_dim, size]).half().to(GPU_TYPE)

            with torch.no_grad():
                res, code = run_and_get_code(foo_c, m, inp)

            torch._dynamo.reset()
            with config.patch(benchmark_epilogue_fusion=False):
                foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
                with torch.no_grad():
                    res2, code2 = run_and_get_code(foo_c, m, inp)

            # NOTE(review): rtol=1.1 is an extremely loose relative tolerance;
            # presumably chosen for fp16 with differing fusion choices - confirm.
            self.assertEqual(res, res2, atol=1e-4, rtol=1.1)
            return code, code2

        @fresh_cache()
        @config.patch(max_autotune_gemm_backends="TRITON")
        def test_equivalent_template_code(self):
            code, code2 = self._equivalent_output_code_impl(256)
            for out_code in [code, code2]:
                FileCheck().check(get_func_call()).check_count(
                    "empty_strided", 1, exactly=True
                ).check("triton_tem_fused_addmm_relu_t_0").check_count(
                    ".reset()" if config.cpp_wrapper else "del", 3, exactly=True
                ).check("" if config.cpp_wrapper else "return").run(out_code[0])

        @fresh_cache()
        @config.patch(max_autotune_gemm_backends="ATEN")
        def test_equivalent_extern_code(self):
            torch._dynamo.reset()
            code, code2 = self._equivalent_output_code_impl(512, 1, False)
            for out_code in [code, code2]:
                FileCheck().check(get_func_call()).check_count(
                    "empty_strided", 1, exactly=True
                ).check("" if config.cpp_wrapper else "extern_kernels.").check_count(
                    ".reset()" if config.cpp_wrapper else "del", 3, exactly=True
                ).check("" if config.cpp_wrapper else "return").run(out_code[0])

        def test_changed_layout(self):
            # cat addmm planning will change layout - make sure propagated
            def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
                return torch.cat(
                    [
                        torch.addmm(a, b, c),
                        torch.addmm(b, c, a),
                    ],
                    1,
                )

            args = [
                torch.randn(4, 4, device=GPU_TYPE),
                torch.randn(4, 4, device=GPU_TYPE),
                torch.randn(4, 4, device=GPU_TYPE),
            ]

            expected = fn(*args)
            actual = torch.compile(fn, mode="max-autotune")(*args)
            self.assertEqual(expected, actual)

            torch._dynamo.reset()
# CPU variants of the template tests; not defined when an MPS backend is available.
if HAS_CPU and not torch.backends.mps.is_available():

    class BenchmarkFusionCpuTest(TestCase):
        """Runs BenchmarkFusionTestTemplate's tests on the CPU backend."""

        device = "cpu"
        common = check_model

    copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCpuTest, "cpu")
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Launch the runner only when at least one targeted backend is present.
    if HAS_GPU_AND_TRITON or HAS_CPU:
        run_tests()