Refactor test_torchinductor_strided_blocks to also support triton CPU (#141587)

This increases test coverage for the Triton CPU backend from just test_torchinductor.py to also cover block pointer lowering.
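
The bulk of the change moves the tests into a device-agnostic CommonTemplate class, turns run_and_compare into a module-level helper that takes the test case explicitly, and stamps out concrete CPU and GPU test classes via test_torchinductor.copy_tests. A minimal, self-contained sketch of that copy_tests pattern (the class and test names below are illustrative; the real helper in test_torchinductor.py also handles skips and expected failures):

import unittest

class CommonTemplate:
    # Tests are written once against self.device and reused for each device.
    def test_smoke(self):
        self.assertIn(self.device, ("cpu", "cuda", "xpu"))

def copy_tests(template, cls, suffix):
    # Copy every test_* method from the template onto the concrete class,
    # appending the device suffix so it shows up in test names.
    for name in dir(template):
        if name.startswith("test_"):
            setattr(cls, f"{name}_{suffix}", getattr(template, name))

class SmokeTestCPU(unittest.TestCase):
    device = "cpu"

copy_tests(CommonTemplate, SmokeTestCPU, "cpu")

if __name__ == "__main__":
    unittest.main()  # runs SmokeTestCPU.test_smoke_cpu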

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141587
Approved by: https://github.com/jansel
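
The CPU test class is gated on Triton actually shipping a CPU backend. Previously test_triton_cpu_backend.py computed this check itself; this PR centralizes it in torch/testing/_internal/inductor_utils.py as TRITON_HAS_CPU (third file below). A rough standalone equivalent of that check (the real version keys off has_triton() rather than a bare import):

try:
    import triton

    # Triton registers one backend per compile target; a "cpu" entry
    # means the triton-cpu backend is installed.
    TRITON_HAS_CPU = "cpu" in triton.backends.backends
except ImportError:
    TRITON_HAS_CPU = False
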
Mwiza Kunda 2024-12-05 09:57:04 +00:00 committed by PyTorch MergeBot
parent 8dd4673cea
commit ad2cc96218
3 changed files with 129 additions and 84 deletions

test/inductor/test_torchinductor_strided_blocks.py

@@ -18,11 +18,17 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_GPU,
requires_gpu,
skip_windows_ci,
TRITON_HAS_CPU,
)
try:
from . import test_torchinductor
except ImportError:
import test_torchinductor
skip_windows_ci(__name__, __file__)
importlib.import_module("filelock")
@@ -30,54 +36,53 @@ importlib.import_module("filelock")
max_block: int = TRITON_MAX_BLOCK["X"]
@requires_gpu()
@config.patch("triton.use_block_ptr", True)
def run_and_compare(
self: InductorTestCase,
func: Callable[..., Any],
*args,
compile_kwargs: Optional[dict] = None,
expected_num_block_pointers: Optional[int] = None,
expected_num_programs: int = 1,
expected_num_triton_kernels: int = 1,
config_patches: Optional[dict] = None,
):
"""
Runs the module through Inductor, comparing to eager reference.
"""
if compile_kwargs is None:
compile_kwargs = {}
if config_patches is None:
config_patches = {}
def flatten_tensors(tensors):
flat, spec = pytree.tree_flatten(tensors)
return flat
with config.patch(config_patches):
compiled = torch.compile(func, backend="inductor", **compile_kwargs)
result, code = run_and_get_code(compiled, *args)
# Check numerical accuracy
ref_tensors = flatten_tensors(func(*args))
actual_tensors = flatten_tensors(result)
for ref, actual in zip(ref_tensors, actual_tensors):
self.assertTrue(torch.allclose(ref, actual))
def count_code(substr: str, expected: Optional[int]):
count = sum(prog.count(substr) for prog in code)
if expected is not None:
self.assertEqual(count, expected)
# Check the code
self.assertEqual(len(code), expected_num_programs)
count_code("@triton.jit", expected_num_triton_kernels)
count_code("tl.make_block_ptr", expected_num_block_pointers)
return result, code
@instantiate_parametrized_tests
class TritonBlockPointerTest(InductorTestCase):
def run_and_compare(
self,
func: Callable[..., Any],
*args,
compile_kwargs: Optional[dict] = None,
expected_num_block_pointers: Optional[int] = None,
expected_num_programs: int = 1,
expected_num_triton_kernels: int = 1,
config_patches: Optional[dict] = None,
):
"""
Runs the module through Inductor, comparing to eager reference.
"""
if compile_kwargs is None:
compile_kwargs = {}
if config_patches is None:
config_patches = {}
def flatten_tensors(tensors):
flat, spec = pytree.tree_flatten(tensors)
return flat
with config.patch(config_patches):
compiled = torch.compile(func, backend="inductor", **compile_kwargs)
result, code = run_and_get_code(compiled, *args)
# Check numerical accuracy
ref_tensors = flatten_tensors(func(*args))
actual_tensors = flatten_tensors(result)
for ref, actual in zip(ref_tensors, actual_tensors):
self.assertTrue(torch.allclose(ref, actual))
def count_code(substr: str, expected: Optional[int]):
count = sum(prog.count(substr) for prog in code)
if expected is not None:
self.assertEqual(count, expected)
# Check the code
self.assertEqual(len(code), expected_num_programs)
count_code("@triton.jit", expected_num_triton_kernels)
count_code("tl.make_block_ptr", expected_num_block_pointers)
return result, code
class CommonTemplate:
@parametrize(
"expected_num_block_pointers,raises",
[
@@ -95,14 +100,17 @@ class TritonBlockPointerTest(InductorTestCase):
def foo(x, y):
return x + y
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
inputs = [torch.randn(8).to(device) for arg_idx in range(2)]
# Expect failure for bad inputs
with self.assertRaises(AssertionError) if raises else contextlib.nullcontext():
# Expect 3 block pointers: 2 inputs 1 output
self.run_and_compare(
foo, *inputs, expected_num_block_pointers=expected_num_block_pointers
run_and_compare(
self,
foo,
*inputs,
expected_num_block_pointers=expected_num_block_pointers,
)
@parametrize("prefer_nd_tiling", [False, True])
@@ -158,7 +166,7 @@ class TritonBlockPointerTest(InductorTestCase):
"""
def get_input() -> torch.Tensor:
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
full = torch.randn(full_size).to(device)
# Use the original tensor's stride by default
@@ -169,7 +177,8 @@ class TritonBlockPointerTest(InductorTestCase):
args = [get_input() for arg_idx in range(2)]
# Expect 3 block pointers: 2 inputs 1 output
self.run_and_compare(
run_and_compare(
self,
torch.add,
*args,
expected_num_block_pointers=3 if require_block_ptr else None,
@@ -206,7 +215,7 @@ class TritonBlockPointerTest(InductorTestCase):
return a + b
def get_input(view_size: Tuple[int]) -> torch.Tensor:
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
full_size = tuple(2 * dim for dim in view_size)
full = torch.randn(full_size).to(device)
view = torch.as_strided(full, view_size, full.stride())
@@ -222,7 +231,8 @@ class TritonBlockPointerTest(InductorTestCase):
self.assertIn(1, all_dims)
# Expect 3 block pointers: 2 inputs one output
self.run_and_compare(
run_and_compare(
self,
foo,
x,
y,
@@ -255,7 +265,7 @@ class TritonBlockPointerTest(InductorTestCase):
return x.expand(y_size).clone()
def get_input(size: Tuple[int]) -> torch.Tensor:
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
full = torch.randn(size).to(device)
view = torch.as_strided(full, size, full.stride())
return view
@@ -272,7 +282,7 @@ class TritonBlockPointerTest(InductorTestCase):
if i != 1:
self.assertEqual(i, j)
result, (triton_code,) = self.run_and_compare(foo, x, y)
result, (triton_code,) = run_and_compare(self, foo, x, y)
@parametrize("prefer_nd_tiling", [False, True])
def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool):
@@ -282,12 +292,13 @@ class TritonBlockPointerTest(InductorTestCase):
full_shape = (8, 8)
col_shape = (full_shape[1], 1)
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
full = torch.randn(full_shape).to(device)
col = torch.as_strided(full, col_shape, full.stride())
# Expect 3 block pointers: 2 inputs one output
result, (triton_code,) = self.run_and_compare(
result, (triton_code,) = run_and_compare(
self,
torch.add,
full,
col,
@@ -353,8 +364,21 @@ class TritonBlockPointerTest(InductorTestCase):
"""
Tests a reduction kernel.
"""
if self.device == "cpu" and all(
# Multiple of max block. Uses loops.
[
view_size == (3 * max_block, 2),
num_block_pointers == 3,
num_triton_kernels == 2,
prefer_nd_tiling is False,
]
):
raise unittest.SkipTest(
"Long test and raises BrokenProcessPool Error if triton CPU"
)
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
full_size = tuple(2 * dim for dim in view_size)
full = torch.randn(full_size).to(device)
view = torch.as_strided(full, view_size, full.stride())
@@ -366,7 +390,8 @@ class TritonBlockPointerTest(InductorTestCase):
# Expect at least 1 block pointer for the input.
# Add 2 more if we generate 2 kernels.
result, (code,) = self.run_and_compare(
result, (code,) = run_and_compare(
self,
torch.sum,
view,
expected_num_block_pointers=num_block_pointers,
@@ -395,7 +420,7 @@ class TritonBlockPointerTest(InductorTestCase):
def foo(x, y):
return torch.sum(x + y)
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
full_size = tuple(2 * dim for dim in view_size)
def get_input() -> torch.Tensor:
@@ -406,7 +431,8 @@ class TritonBlockPointerTest(InductorTestCase):
inputs = [get_input() for input_idx in range(2)]
# Expect 2 block pointers: inputs
result, (code,) = self.run_and_compare(
result, (code,) = run_and_compare(
self,
foo,
*inputs,
expected_num_block_pointers=num_block_pointers,
@@ -422,7 +448,7 @@ class TritonBlockPointerTest(InductorTestCase):
def foo(x):
return x - 1
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
full_size = (3 * max_block, 3)
view_size = (3 * max_block, 2)
full = torch.randn(full_size).to(device)
@@ -437,7 +463,7 @@ class TritonBlockPointerTest(InductorTestCase):
self.assertTrue(len(nontrivial_dims) > 1)
# Expect 2 block pointers: input and output
self.run_and_compare(foo, view, expected_num_block_pointers=2)
run_and_compare(self, foo, view, expected_num_block_pointers=2)
def test_dynamic_shapes_generic(self):
"""
@@ -445,13 +471,13 @@ class TritonBlockPointerTest(InductorTestCase):
expected. This only checks that the analysis doesn't break this case.
"""
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
full_size = (8, 8)
view_size = (4, 4)
full = torch.randn(full_size).to(device)
view = torch.as_strided(full, view_size, full.stride())
self.run_and_compare(torch.div, view, view, compile_kwargs={"dynamic": True})
run_and_compare(self, torch.div, view, view, compile_kwargs={"dynamic": True})
@unittest.skip(reason="Dynamo tracing error")
def test_dynamic_shapes_multiple_max_block(self):
@@ -467,13 +493,13 @@ class TritonBlockPointerTest(InductorTestCase):
view = torch.as_strided(full, view_size, full.stride())
return view + view
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
x_size = (1, 1)
x = torch.randn(x_size).to(device)
# Expect 2 block pointers: input and output
self.run_and_compare(
x, compile_kwargs={"dynamic": True}, expected_num_block_pointers=2
run_and_compare(
self, x, compile_kwargs={"dynamic": True}, expected_num_block_pointers=2
)
@parametrize(
@@ -520,14 +546,15 @@ class TritonBlockPointerTest(InductorTestCase):
"""
def get_input() -> torch.Tensor:
device = torch.device(GPU_TYPE)
device = torch.device(self.device)
full = torch.randn(full_size).to(device)
return torch.as_strided(full, view_size, full.stride())
args = [get_input() for arg_idx in range(2)]
# Expect up to 3 block pointers: 2 inputs 1 output.
result, code = self.run_and_compare(
result, code = run_and_compare(
self,
torch.add,
*args,
expected_num_block_pointers=num_block_pointers,
@@ -558,8 +585,9 @@ class TritonBlockPointerTest(InductorTestCase):
return clone_0, clone_1
inps = (torch.rand((8, 2048), device=GPU_TYPE, dtype=torch.float32),) * 2
result, code = self.run_and_compare(
inps = (torch.rand((8, 2048), device=self.device, dtype=torch.float32),) * 2
result, code = run_and_compare(
self,
func,
*inps,
expected_num_triton_kernels=2,
@@ -568,8 +596,26 @@ class TritonBlockPointerTest(InductorTestCase):
self.assertTrue("Min" not in code[0])
@unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend")
@config.patch(cpu_backend="triton")
@config.patch("triton.use_block_ptr", True)
class TritonBlockPointerTestCPU(InductorTestCase):
device = "cpu"
test_torchinductor.copy_tests(CommonTemplate, TritonBlockPointerTestCPU, "cpu")
@unittest.skipIf(not HAS_GPU, "requires triton GPU backend")
@config.patch("triton.use_block_ptr", True)
class TritonBlockPointerTestGPU(InductorTestCase):
device = GPU_TYPE
test_torchinductor.copy_tests(CommonTemplate, TritonBlockPointerTestGPU, GPU_TYPE)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_GPU:
if HAS_GPU or TRITON_HAS_CPU:
run_tests(needs="filelock")

test/inductor/test_triton_cpu_backend.py

@@ -1,8 +1,7 @@
# Owner(s): ["module: inductor"]
from torch._inductor import config
from torch._inductor.test_case import run_tests
from torch.testing._internal.inductor_utils import HAS_CPU
from torch.utils._triton import has_triton
from torch.testing._internal.inductor_utils import HAS_CPU, TRITON_HAS_CPU
try:
@@ -10,13 +9,6 @@ try:
except ImportError:
import test_torchinductor
if has_triton():
import triton
TRITON_HAS_CPU = "cpu" in triton.backends.backends
else:
TRITON_HAS_CPU = False
if HAS_CPU and TRITON_HAS_CPU:

torch/testing/_internal/inductor_utils.py

@@ -41,6 +41,13 @@ HAS_CPU = LazyVal(test_cpu)
HAS_TRITON = has_triton()
if HAS_TRITON:
import triton
TRITON_HAS_CPU = "cpu" in triton.backends.backends
else:
TRITON_HAS_CPU = False
HAS_CUDA = torch.cuda.is_available() and HAS_TRITON
HAS_XPU = torch.xpu.is_available() and HAS_TRITON