mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Fix 3D tiling with permute (#147249)
This PR adds a test case and a tiny fix for 3D tiling. Before this PR, tiling would crash because one of the candidates lacked a `"y"` dimension. Now, when we're calculating 3D tiling candidates, we assume the y size is 1 if it's missing.
The test case implements a 3D permute using block pointers.
```
@triton.jit
def triton_poi_fused_add_0(in_ptr0, out_ptr0, znumel, ynumel, xnumel, ZBLOCK : tl.constexpr, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    znumel = 51
    ynumel = 51
    xnumel = 51
    zoffset = tl.program_id(2) * ZBLOCK
    zindex = zoffset + tl.arange(0, ZBLOCK)[None, None, :]
    zmask = zindex < znumel
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = xindex < xnumel
    x2 = xindex
    y1 = yindex
    z0 = zindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[51, 51, 51], strides=[1, 51, 2601], block_shape=[XBLOCK, YBLOCK, ZBLOCK], order=[2, 1, 0], offsets=[xoffset, yoffset, zoffset]), boundary_check=[0, 1, 2])
    tmp1 = tl.load(tl.make_block_ptr(in_ptr0, shape=[51, 51, 51], strides=[51, 1, 2601], block_shape=[XBLOCK, YBLOCK, ZBLOCK], order=[2, 1, 0], offsets=[xoffset, yoffset, zoffset]), boundary_check=[0, 1, 2])
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 + tmp0
    tmp4 = tmp2 + tmp3
    tl.store(tl.make_block_ptr(out_ptr0, shape=[51, 51, 51], strides=[1, 51, 2601], block_shape=[XBLOCK, YBLOCK, ZBLOCK], order=[2, 1, 0], offsets=[xoffset, yoffset, zoffset]), tl.broadcast_to(tmp4, [XBLOCK, YBLOCK, ZBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147249
Approved by: https://github.com/jansel
This commit is contained in:
parent
44ee9ca593
commit
1677a31019
|
|
@ -20,6 +20,7 @@ from torch.testing._internal.common_utils import (
|
|||
from torch.testing._internal.inductor_utils import (
|
||||
GPU_TYPE,
|
||||
HAS_GPU,
|
||||
requires_gpu,
|
||||
skip_windows_ci,
|
||||
TRITON_HAS_CPU,
|
||||
)
|
||||
|
|
@ -895,6 +896,34 @@ class CommonTemplate:
|
|||
)
|
||||
self.assertTrue("Min" not in code[0])
|
||||
|
||||
@requires_gpu() # FIXME this test failed on Triton-CPU
|
||||
def test_3d_permute_tiling(self):
|
||||
"""
|
||||
Test 3D tiling with permute.
|
||||
"""
|
||||
|
||||
def foo(x, y, z):
|
||||
dims = [0, 2, 1]
|
||||
a = x.permute(dims=dims) + y
|
||||
b = (z + y).permute(dims=dims)
|
||||
return a + b
|
||||
|
||||
inps = (torch.rand((51, 51, 51), device=self.device, dtype=torch.float32),) * 3
|
||||
result, (code,) = run_and_compare(
|
||||
self,
|
||||
foo,
|
||||
*inps,
|
||||
expected_num_triton_kernels=1,
|
||||
expected_num_block_pointers=3,
|
||||
config_patches={
|
||||
"triton.max_tiles": 3,
|
||||
"triton.prefer_nd_tiling": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Check for 3D tiling
|
||||
self.assertIn("ZBLOCK", code)
|
||||
|
||||
|
||||
@unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend")
|
||||
@config.patch(cpu_backend="triton")
|
||||
|
|
|
|||
|
|
@ -2011,8 +2011,8 @@ class SIMDScheduling(BaseScheduling):
|
|||
def convert_tiling_to_3d(
|
||||
tiling0: dict[str, sympy.Expr], tiling1: dict[str, sympy.Expr]
|
||||
) -> Optional[dict[str, sympy.Expr]]:
|
||||
a0, a1 = tiling0["x"], tiling0["y"]
|
||||
b0, b1 = tiling1["x"], tiling1["y"]
|
||||
a0, a1 = tiling0["x"], tiling0.get("y", 1)
|
||||
b0, b1 = tiling1["x"], tiling1.get("y", 1)
|
||||
if V.graph.sizevars.size_hint(a1 - b1) == 0:
|
||||
return None
|
||||
if V.graph.sizevars.size_hint(a1 - b1) < 0:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user