[Inductor] Fix 3D tiling with permute (#147249)

This PR adds a test case and tiny fix for 3D tiling. Before this PR, tiling would crash because one of the candidates lacked a `"y"` dimension. Now, when we're calculating 3D tiling candidates, we assume the y size is 1 if it's missing.

The test case implements a 3D permute using block pointers.

```
@triton.jit
def triton_poi_fused_add_0(in_ptr0, out_ptr0, znumel, ynumel, xnumel, ZBLOCK : tl.constexpr, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    znumel = 51
    ynumel = 51
    xnumel = 51
    zoffset = tl.program_id(2) * ZBLOCK
    zindex = zoffset + tl.arange(0, ZBLOCK)[None, None, :]
    zmask = zindex < znumel
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = xindex < xnumel
    x2 = xindex
    y1 = yindex
    z0 = zindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[51, 51, 51], strides=[1, 51, 2601], block_shape=[XBLOCK, YBLOCK, ZBLOCK], order=[2, 1, 0], offsets=[xoffset, yoffset, zoffset]), boundary_check=[0, 1, 2])
    tmp1 = tl.load(tl.make_block_ptr(in_ptr0, shape=[51, 51, 51], strides=[51, 1, 2601], block_shape=[XBLOCK, YBLOCK, ZBLOCK], order=[2, 1, 0], offsets=[xoffset, yoffset, zoffset]), boundary_check=[0, 1, 2])
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 + tmp0
    tmp4 = tmp2 + tmp3
    tl.store(tl.make_block_ptr(out_ptr0, shape=[51, 51, 51], strides=[1, 51, 2601], block_shape=[XBLOCK, YBLOCK, ZBLOCK], order=[2, 1, 0], offsets=[xoffset, yoffset, zoffset]), tl.broadcast_to(tmp4, [XBLOCK, YBLOCK, ZBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147249
Approved by: https://github.com/jansel
This commit is contained in:
Blaine Burton Rister 2025-02-15 23:28:34 +00:00 committed by PyTorch MergeBot
parent 44ee9ca593
commit 1677a31019
2 changed files with 31 additions and 2 deletions

View File

@ -20,6 +20,7 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_GPU,
requires_gpu,
skip_windows_ci,
TRITON_HAS_CPU,
)
@ -895,6 +896,34 @@ class CommonTemplate:
)
self.assertTrue("Min" not in code[0])
@requires_gpu() # FIXME this test failed on Triton-CPU
def test_3d_permute_tiling(self):
"""
Test 3D tiling with permute.
"""
def foo(x, y, z):
dims = [0, 2, 1]
a = x.permute(dims=dims) + y
b = (z + y).permute(dims=dims)
return a + b
inps = (torch.rand((51, 51, 51), device=self.device, dtype=torch.float32),) * 3
result, (code,) = run_and_compare(
self,
foo,
*inps,
expected_num_triton_kernels=1,
expected_num_block_pointers=3,
config_patches={
"triton.max_tiles": 3,
"triton.prefer_nd_tiling": True,
},
)
# Check for 3D tiling
self.assertIn("ZBLOCK", code)
@unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend")
@config.patch(cpu_backend="triton")

View File

@ -2011,8 +2011,8 @@ class SIMDScheduling(BaseScheduling):
def convert_tiling_to_3d(
tiling0: dict[str, sympy.Expr], tiling1: dict[str, sympy.Expr]
) -> Optional[dict[str, sympy.Expr]]:
a0, a1 = tiling0["x"], tiling0["y"]
b0, b1 = tiling1["x"], tiling1["y"]
a0, a1 = tiling0["x"], tiling0.get("y", 1)
b0, b1 = tiling1["x"], tiling1.get("y", 1)
if V.graph.sizevars.size_hint(a1 - b1) == 0:
return None
if V.graph.sizevars.size_hint(a1 - b1) < 0: