mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Refactor test_torchinductor_strided_blocks to also support triton CPU (#141587)
This increases test coverage for triton CPU from just test_torchinductor.py to also testing block pointer lowering. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/141587 Approved by: https://github.com/jansel
This commit is contained in:
parent
8dd4673cea
commit
ad2cc96218
|
|
@ -18,11 +18,17 @@ from torch.testing._internal.common_utils import (
|
|||
from torch.testing._internal.inductor_utils import (
|
||||
GPU_TYPE,
|
||||
HAS_GPU,
|
||||
requires_gpu,
|
||||
skip_windows_ci,
|
||||
TRITON_HAS_CPU,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from . import test_torchinductor
|
||||
except ImportError:
|
||||
import test_torchinductor
|
||||
|
||||
|
||||
skip_windows_ci(__name__, __file__)
|
||||
|
||||
importlib.import_module("filelock")
|
||||
|
|
@ -30,54 +36,53 @@ importlib.import_module("filelock")
|
|||
max_block: int = TRITON_MAX_BLOCK["X"]
|
||||
|
||||
|
||||
@requires_gpu()
|
||||
@config.patch("triton.use_block_ptr", True)
|
||||
def run_and_compare(
|
||||
self: InductorTestCase,
|
||||
func: Callable[..., Any],
|
||||
*args,
|
||||
compile_kwargs: Optional[dict] = None,
|
||||
expected_num_block_pointers: Optional[int] = None,
|
||||
expected_num_programs: int = 1,
|
||||
expected_num_triton_kernels: int = 1,
|
||||
config_patches: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Runs the module through Inductor, comparing to eager reference.
|
||||
"""
|
||||
if compile_kwargs is None:
|
||||
compile_kwargs = {}
|
||||
if config_patches is None:
|
||||
config_patches = {}
|
||||
|
||||
def flatten_tensors(tensors):
|
||||
flat, spec = pytree.tree_flatten(tensors)
|
||||
return flat
|
||||
|
||||
with config.patch(config_patches):
|
||||
compiled = torch.compile(func, backend="inductor", **compile_kwargs)
|
||||
result, code = run_and_get_code(compiled, *args)
|
||||
|
||||
# Check numerical accuracy
|
||||
ref_tensors = flatten_tensors(func(*args))
|
||||
actual_tensors = flatten_tensors(result)
|
||||
for ref, actual in zip(ref_tensors, actual_tensors):
|
||||
self.assertTrue(torch.allclose(ref, actual))
|
||||
|
||||
def count_code(substr: str, expected: Optional[int]):
|
||||
count = sum(prog.count(substr) for prog in code)
|
||||
if expected is not None:
|
||||
self.assertEqual(count, expected)
|
||||
|
||||
# Check the code
|
||||
self.assertEqual(len(code), expected_num_programs)
|
||||
count_code("@triton.jit", expected_num_triton_kernels)
|
||||
count_code("tl.make_block_ptr", expected_num_block_pointers)
|
||||
|
||||
return result, code
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TritonBlockPointerTest(InductorTestCase):
|
||||
def run_and_compare(
|
||||
self,
|
||||
func: Callable[..., Any],
|
||||
*args,
|
||||
compile_kwargs: Optional[dict] = None,
|
||||
expected_num_block_pointers: Optional[int] = None,
|
||||
expected_num_programs: int = 1,
|
||||
expected_num_triton_kernels: int = 1,
|
||||
config_patches: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Runs the module through Inductor, comparing to eager reference.
|
||||
"""
|
||||
if compile_kwargs is None:
|
||||
compile_kwargs = {}
|
||||
if config_patches is None:
|
||||
config_patches = {}
|
||||
|
||||
def flatten_tensors(tensors):
|
||||
flat, spec = pytree.tree_flatten(tensors)
|
||||
return flat
|
||||
|
||||
with config.patch(config_patches):
|
||||
compiled = torch.compile(func, backend="inductor", **compile_kwargs)
|
||||
result, code = run_and_get_code(compiled, *args)
|
||||
|
||||
# Check numerical accuracy
|
||||
ref_tensors = flatten_tensors(func(*args))
|
||||
actual_tensors = flatten_tensors(result)
|
||||
for ref, actual in zip(ref_tensors, actual_tensors):
|
||||
self.assertTrue(torch.allclose(ref, actual))
|
||||
|
||||
def count_code(substr: str, expected: Optional[int]):
|
||||
count = sum(prog.count(substr) for prog in code)
|
||||
if expected is not None:
|
||||
self.assertEqual(count, expected)
|
||||
|
||||
# Check the code
|
||||
self.assertEqual(len(code), expected_num_programs)
|
||||
count_code("@triton.jit", expected_num_triton_kernels)
|
||||
count_code("tl.make_block_ptr", expected_num_block_pointers)
|
||||
|
||||
return result, code
|
||||
|
||||
class CommonTemplate:
|
||||
@parametrize(
|
||||
"expected_num_block_pointers,raises",
|
||||
[
|
||||
|
|
@ -95,14 +100,17 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
def foo(x, y):
|
||||
return x + y
|
||||
|
||||
device = torch.device(GPU_TYPE)
|
||||
device = torch.device(self.device)
|
||||
inputs = [torch.randn(8).to(device) for arg_idx in range(2)]
|
||||
|
||||
# Expect failure for bad inputs
|
||||
with self.assertRaises(AssertionError) if raises else contextlib.nullcontext():
|
||||
# Expect 3 block pointers: 2 inputs 1 output
|
||||
self.run_and_compare(
|
||||
foo, *inputs, expected_num_block_pointers=expected_num_block_pointers
|
||||
run_and_compare(
|
||||
self,
|
||||
foo,
|
||||
*inputs,
|
||||
expected_num_block_pointers=expected_num_block_pointers,
|
||||
)
|
||||
|
||||
@parametrize("prefer_nd_tiling", [False, True])
|
||||
|
|
@ -158,7 +166,7 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
"""
|
||||
|
||||
def get_input() -> torch.Tensor:
|
||||
device = torch.device(GPU_TYPE)
|
||||
device = torch.device(self.device)
|
||||
full = torch.randn(full_size).to(device)
|
||||
|
||||
# Use the original tensor's stride by default
|
||||
|
|
@ -169,7 +177,8 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
args = [get_input() for arg_idx in range(2)]
|
||||
|
||||
# Expect 3 block pointers: 2 inputs 1 output
|
||||
self.run_and_compare(
|
||||
run_and_compare(
|
||||
self,
|
||||
torch.add,
|
||||
*args,
|
||||
expected_num_block_pointers=3 if require_block_ptr else None,
|
||||
|
|
@ -206,7 +215,7 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
return a + b
|
||||
|
||||
def get_input(view_size: Tuple[int]) -> torch.Tensor:
|
||||
device = torch.device(GPU_TYPE)
|
||||
device = torch.device(self.device)
|
||||
full_size = tuple(2 * dim for dim in view_size)
|
||||
full = torch.randn(full_size).to(device)
|
||||
view = torch.as_strided(full, view_size, full.stride())
|
||||
|
|
@ -222,7 +231,8 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
self.assertIn(1, all_dims)
|
||||
|
||||
# Expect 3 block pointers: 2 inputs one output
|
||||
self.run_and_compare(
|
||||
run_and_compare(
|
||||
self,
|
||||
foo,
|
||||
x,
|
||||
y,
|
||||
|
|
@ -255,7 +265,7 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
return x.expand(y_size).clone()
|
||||
|
||||
def get_input(size: Tuple[int]) -> torch.Tensor:
|
||||
device = torch.device(GPU_TYPE)
|
||||
device = torch.device(self.device)
|
||||
full = torch.randn(size).to(device)
|
||||
view = torch.as_strided(full, size, full.stride())
|
||||
return view
|
||||
|
|
@ -272,7 +282,7 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
if i != 1:
|
||||
self.assertEqual(i, j)
|
||||
|
||||
result, (triton_code,) = self.run_and_compare(foo, x, y)
|
||||
result, (triton_code,) = run_and_compare(self, foo, x, y)
|
||||
|
||||
@parametrize("prefer_nd_tiling", [False, True])
|
||||
def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool):
|
||||
|
|
@ -282,12 +292,13 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
|
||||
full_shape = (8, 8)
|
||||
col_shape = (full_shape[1], 1)
|
||||
device = torch.device(GPU_TYPE)
|
||||
device = torch.device(self.device)
|
||||
full = torch.randn(full_shape).to(device)
|
||||
col = torch.as_strided(full, col_shape, full.stride())
|
||||
|
||||
# Expect 3 block pointers: 2 inputs one output
|
||||
result, (triton_code,) = self.run_and_compare(
|
||||
result, (triton_code,) = run_and_compare(
|
||||
self,
|
||||
torch.add,
|
||||
full,
|
||||
col,
|
||||
|
|
@ -353,8 +364,21 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
"""
|
||||
Tests a reduction kernel.
|
||||
"""
|
||||
if self.device == "cpu" and all(
|
||||
# Multiple of max block. Uses loops.
|
||||
[
|
||||
view_size == (3 * max_block, 2),
|
||||
num_block_pointers == 3,
|
||||
num_triton_kernels == 2,
|
||||
prefer_nd_tiling is False,
|
||||
]
|
||||
):
|
||||
raise unittest.SkipTest(
|
||||
"Long test and raises BrokenProcessPool Error if triton CPU"
|
||||
)
|
||||
|
||||
device = torch.device(self.device)
|
||||
|
||||
device = torch.device(GPU_TYPE)
|
||||
full_size = tuple(2 * dim for dim in view_size)
|
||||
full = torch.randn(full_size).to(device)
|
||||
view = torch.as_strided(full, view_size, full.stride())
|
||||
|
|
@ -366,7 +390,8 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
|
||||
# Expect at least 1 block pointer for the input.
|
||||
# Add 2 more if we generate 2 kernels.
|
||||
result, (code,) = self.run_and_compare(
|
||||
result, (code,) = run_and_compare(
|
||||
self,
|
||||
torch.sum,
|
||||
view,
|
||||
expected_num_block_pointers=num_block_pointers,
|
||||
|
|
@ -395,7 +420,7 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
def foo(x, y):
|
||||
return torch.sum(x + y)
|
||||
|
||||
device = torch.device(GPU_TYPE)
|
||||
device = torch.device(self.device)
|
||||
full_size = tuple(2 * dim for dim in view_size)
|
||||
|
||||
def get_input() -> torch.Tensor:
|
||||
|
|
@ -406,7 +431,8 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
inputs = [get_input() for input_idx in range(2)]
|
||||
|
||||
# Expect 2 block pointers: inputs
|
||||
result, (code,) = self.run_and_compare(
|
||||
result, (code,) = run_and_compare(
|
||||
self,
|
||||
foo,
|
||||
*inputs,
|
||||
expected_num_block_pointers=num_block_pointers,
|
||||
|
|
@ -422,7 +448,7 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
def foo(x):
|
||||
return x - 1
|
||||
|
||||
device = torch.device(GPU_TYPE)
|
||||
device = torch.device(self.device)
|
||||
full_size = (3 * max_block, 3)
|
||||
view_size = (3 * max_block, 2)
|
||||
full = torch.randn(full_size).to(device)
|
||||
|
|
@ -437,7 +463,7 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
self.assertTrue(len(nontrivial_dims) > 1)
|
||||
|
||||
# Expect 2 block pointers: input and output
|
||||
self.run_and_compare(foo, view, expected_num_block_pointers=2)
|
||||
run_and_compare(self, foo, view, expected_num_block_pointers=2)
|
||||
|
||||
def test_dynamic_shapes_generic(self):
|
||||
"""
|
||||
|
|
@ -445,13 +471,13 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
expected. This only checks that the analysis doesn't break this case.
|
||||
"""
|
||||
|
||||
device = torch.device(GPU_TYPE)
|
||||
device = torch.device(self.device)
|
||||
full_size = (8, 8)
|
||||
view_size = (4, 4)
|
||||
full = torch.randn(full_size).to(device)
|
||||
view = torch.as_strided(full, view_size, full.stride())
|
||||
|
||||
self.run_and_compare(torch.div, view, view, compile_kwargs={"dynamic": True})
|
||||
run_and_compare(self, torch.div, view, view, compile_kwargs={"dynamic": True})
|
||||
|
||||
@unittest.skip(reason="Dynamo tracing error")
|
||||
def test_dynamic_shapes_multiple_max_block(self):
|
||||
|
|
@ -467,13 +493,13 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
view = torch.as_strided(full, view_size, full.stride())
|
||||
return view + view
|
||||
|
||||
device = torch.device(GPU_TYPE)
|
||||
device = torch.device(self.device)
|
||||
x_size = (1, 1)
|
||||
x = torch.randn(x_size).to(device)
|
||||
|
||||
# Expect 2 block pointers: input and output
|
||||
self.run_and_compare(
|
||||
x, compile_kwargs={"dynamic": True}, expected_num_block_pointers=2
|
||||
run_and_compare(
|
||||
self, x, compile_kwargs={"dynamic": True}, expected_num_block_pointers=2
|
||||
)
|
||||
|
||||
@parametrize(
|
||||
|
|
@ -520,14 +546,15 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
"""
|
||||
|
||||
def get_input() -> torch.Tensor:
|
||||
device = torch.device(GPU_TYPE)
|
||||
device = torch.device(self.device)
|
||||
full = torch.randn(full_size).to(device)
|
||||
return torch.as_strided(full, view_size, full.stride())
|
||||
|
||||
args = [get_input() for arg_idx in range(2)]
|
||||
|
||||
# Expect up to 3 block pointers: 2 inputs 1 output.
|
||||
result, code = self.run_and_compare(
|
||||
result, code = run_and_compare(
|
||||
self,
|
||||
torch.add,
|
||||
*args,
|
||||
expected_num_block_pointers=num_block_pointers,
|
||||
|
|
@ -558,8 +585,9 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
|
||||
return clone_0, clone_1
|
||||
|
||||
inps = (torch.rand((8, 2048), device=GPU_TYPE, dtype=torch.float32),) * 2
|
||||
result, code = self.run_and_compare(
|
||||
inps = (torch.rand((8, 2048), device=self.device, dtype=torch.float32),) * 2
|
||||
result, code = run_and_compare(
|
||||
self,
|
||||
func,
|
||||
*inps,
|
||||
expected_num_triton_kernels=2,
|
||||
|
|
@ -568,8 +596,26 @@ class TritonBlockPointerTest(InductorTestCase):
|
|||
self.assertTrue("Min" not in code[0])
|
||||
|
||||
|
||||
@unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend")
|
||||
@config.patch(cpu_backend="triton")
|
||||
@config.patch("triton.use_block_ptr", True)
|
||||
class TritonBlockPointerTestCPU(InductorTestCase):
|
||||
device = "cpu"
|
||||
|
||||
|
||||
test_torchinductor.copy_tests(CommonTemplate, TritonBlockPointerTestCPU, "cpu")
|
||||
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "requires triton GPU backend")
|
||||
@config.patch("triton.use_block_ptr", True)
|
||||
class TritonBlockPointerTestGPU(InductorTestCase):
|
||||
device = GPU_TYPE
|
||||
|
||||
|
||||
test_torchinductor.copy_tests(CommonTemplate, TritonBlockPointerTestGPU, GPU_TYPE)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
if HAS_GPU:
|
||||
if HAS_GPU or TRITON_HAS_CPU:
|
||||
run_tests(needs="filelock")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
# Owner(s): ["module: inductor"]
|
||||
from torch._inductor import config
|
||||
from torch._inductor.test_case import run_tests
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU
|
||||
from torch.utils._triton import has_triton
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU, TRITON_HAS_CPU
|
||||
|
||||
|
||||
try:
|
||||
|
|
@ -10,13 +9,6 @@ try:
|
|||
except ImportError:
|
||||
import test_torchinductor
|
||||
|
||||
if has_triton():
|
||||
import triton
|
||||
|
||||
TRITON_HAS_CPU = "cpu" in triton.backends.backends
|
||||
else:
|
||||
TRITON_HAS_CPU = False
|
||||
|
||||
|
||||
if HAS_CPU and TRITON_HAS_CPU:
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,13 @@ HAS_CPU = LazyVal(test_cpu)
|
|||
|
||||
HAS_TRITON = has_triton()
|
||||
|
||||
if HAS_TRITON:
|
||||
import triton
|
||||
TRITON_HAS_CPU = "cpu" in triton.backends.backends
|
||||
else:
|
||||
TRITON_HAS_CPU = False
|
||||
|
||||
|
||||
HAS_CUDA = torch.cuda.is_available() and HAS_TRITON
|
||||
|
||||
HAS_XPU = torch.xpu.is_available() and HAS_TRITON
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user