[inductor] Fix 3d tiling (#141709)

Fixes #141121

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141709
Approved by: https://github.com/eellison
Author:    Jason Ansel
Date:      2024-11-29 19:36:05 -08:00
Committed: PyTorch MergeBot
parent 90f19fee8a
commit b2fe1b9409
8 changed files with 50 additions and 17 deletions
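
For context, a minimal sketch of how to exercise the 3-D tiling path this commit fixes (the add function, shapes, and device below are illustrative; the two config knobs are the same ones the new test patches, and direct attribute assignment on torch._inductor.config is assumed to be available):

import torch
import torch._inductor.config as inductor_config

# Opt in to higher-dimensional tiling; max_tiles=3 selects the experimental 3-D path.
inductor_config.triton.prefer_nd_tiling = True
inductor_config.triton.max_tiles = 3

def add(a, b):
    return a + b

# Discontiguous 5-D views, mirroring the shapes used by test_3d_tiling below.
full = torch.randn(5, 5, 5, 5, 5, device="cuda")
a = torch.as_strided(full, (3, 3, 5, 3, 5), full.stride())
b = torch.as_strided(full.clone(), (3, 3, 5, 3, 5), full.stride())

out = torch.compile(add)(a, b)  # the generated Triton kernel should tile over x, y and z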


@@ -1495,6 +1495,33 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
             device_stats2["active.all.peak"] <= device_stats["active.all.peak"]
         )
 
+    @config.patch(
+        {
+            "triton.prefer_nd_tiling": True,
+            "triton.max_tiles": 3,
+        }
+    )
+    def test_3d_tiling(self):
+        full_size, view_size, num_block_pointers, num_tiles = (
+            (5, 5, 5, 5, 5),
+            (3, 3, 5, 3, 5),
+            1,
+            2,
+        )
+        GPU_TYPE = "cuda"
+
+        def get_input() -> torch.Tensor:
+            device = torch.device(GPU_TYPE)
+            full = torch.randn(full_size).to(device)
+            return torch.as_strided(full, view_size, full.stride())
+
+        a, b = get_input(), get_input()
+        opt_fn = torch.compile(functools.partial(torch.add))
+        result, (code,) = run_and_get_code(opt_fn, a, b)
+        self.assertEqual(result, a + b)
+        self.assertIn("znumel", code)
+
     def test_repeated_masked_load(self):
         target_size = (8, 2)
         mem_eff_temporal_upsampling_interp_chunks = 2


@@ -80,8 +80,8 @@ op0.sizes = ([256], [])
 arg0_1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
 buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
 class op0_loop_body:
-    var_ranges = {z0: 256}
-    index0 = z0
+    var_ranges = {p0: 256}
+    index0 = p0
     def body(self, ops):
         get_index = self.get_index('index0')
         load = ops.load('arg0_1', get_index)
@@ -107,8 +107,8 @@ op1.sizes = ([256], [])
 buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
 buf1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
 class op1_loop_body:
-    var_ranges = {z0: 256}
-    index0 = z0
+    var_ranges = {p0: 256}
+    index0 = p0
     def body(self, ops):
         get_index = self.get_index('index0')
         load = ops.load('buf0', get_index)
@@ -161,8 +161,8 @@ op0.sizes = ([256], [])
 arg0_1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
 buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
 class op0_loop_body:
-    var_ranges = {z0: 256}
-    index0 = z0
+    var_ranges = {p0: 256}
+    index0 = p0
     def body(self, ops):
         get_index = self.get_index('index0')
         load = ops.load('arg0_1', get_index)
@@ -187,8 +187,8 @@ op1.sizes = ([256], [])
 buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
 buf1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
 class op1_loop_body:
-    var_ranges = {z0: 256}
-    index0 = z0
+    var_ranges = {p0: 256}
+    index0 = p0
     def body(self, ops):
         get_index = self.get_index('index0')
         load = ops.load('buf0', get_index)


@@ -117,10 +117,10 @@ class ImplDetailTest(TestCase):
         snode = SchedulerNode(V.graph.scheduler, buf)
         snode.apply_new_loop_order([1, 0])
         prefix1 = self._get_snode_body_sym_prefix(snode)
-        self.assertTrue(prefix1 == "z")
+        self.assertTrue(prefix1 == "p")
         snode.apply_new_loop_order([1, 0])
         prefix2 = self._get_snode_body_sym_prefix(snode)
-        self.assertTrue(prefix2 == "z")
+        self.assertTrue(prefix2 == "p")
 
     def test_reorder_and_merge_loops(self):
         sizes = (1024, 2048)
@@ -163,7 +163,7 @@ class ImplDetailTest(TestCase):
         _, body = buf.simplify_and_reorder()
         new_body = body.reorder_iter_loops([1, 2, 3, 0])
 
-        z0, z1, z2, z3 = (sympy_index_symbol(f"z{i}") for i in range(4))
+        z0, z1, z2, z3 = (sympy_index_symbol(f"p{i}") for i in range(4))
         self.assertEqual(body.var_ranges, {z0: 128, z1: 4, z2: 49, z3: 49})
         self.assertEqual(
             body.indexing_exprs["index0"],


@@ -155,14 +155,14 @@ class TritonSymbols:
     block_offsets = {
         symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
-        for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
+        for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
     }
 
     block_sizes = {
         symt: sympy.Symbol(
             f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True
         )
-        for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
+        for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
     }
 
     @classmethod
@@ -1564,7 +1564,7 @@ class TritonKernel(SIMDKernel):
             else:
                 # var is one of xN, yN or rN
                 assert symbol_is_type(
-                    var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK)
+                    var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK)
                 ), var.name
                 mask_vars.add(f"{var.name[0]}mask")


@@ -963,6 +963,10 @@ class triton:
     dense_indexing = False
 
     # limit tiling dimensions
+    # - max_tiles=1 disables tiling
+    # - max_tiles=2 is the default
+    # - max_tiles=3 is experimental and may have bugs
+    # higher values are unsupported
     max_tiles = 2
 
     # Prefer higher dimensional tilings. This simplifies indexing expressions, making
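
As a usage sketch (not part of the diff): because max_tiles=3 is flagged as experimental, it is most naturally applied to a limited scope with config.patch, the same mechanism the new test uses as a decorator; the lambda below is illustrative:

import torch
from torch._inductor import config

# Scoped override: only compilations triggered inside this block see max_tiles=3.
with config.patch({"triton.prefer_nd_tiling": True, "triton.max_tiles": 3}):
    compiled = torch.compile(lambda a, b: a + b)
    # ... call compiled(...) on CUDA tensors here ...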


@@ -4155,7 +4155,7 @@ class ComputedBuffer(OperationBuffer):
         (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
             iter_ranges,
             reduce_ranges,
-            prefix="z",
+            prefix="p",
         )
         body = LoopBody(
             body,


@@ -215,7 +215,7 @@ class LoopBody:
         # use the original symbol prefix
         # Can try to optimize if this is a bottleneck for compilation time
         (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
-            iter_sizes, reduce_sizes, prefix="z"
+            iter_sizes, reduce_sizes, prefix="p"
        )
        new_body2 = LoopBody(
            new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
@@ -259,7 +259,7 @@ class LoopBody:
         # use the original symbol prefix so we can do multiple round of reordering
         (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
-            *new_sizes, prefix="z"  # type: ignore[arg-type]
+            *new_sizes, prefix="p"  # type: ignore[arg-type]
         )
         new_body = LoopBody(
             loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2


@@ -47,6 +47,7 @@ class SymT(Enum):
     # Inductor: iteration domain for blockIdx.x/blockIdx.y
     XBLOCK = auto()
     YBLOCK = auto()
+    ZBLOCK = auto()
     # Inductor: this is used solely for dynamic_reshape_indexer
     VIEW = auto()
     # Alternate (non-modular) indexing used in halide kernels
@@ -70,6 +71,7 @@ prefix_str = {
     SymT.TEMPLATE_INDEX: "idx",
     SymT.XBLOCK: "x",
     SymT.YBLOCK: "y",
+    SymT.ZBLOCK: "z",
     SymT.INDIRECT: "indirect",  # false aliasing?
     SymT.VIEW: "view",
     SymT.HALIDE: "h",
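
With ZBLOCK mapped to the prefix "z", the z dimension gets the same family of names the x and y dimensions already use (z0, zoffset, ZBLOCK, zmask, and the znumel string the new test greps for); this is also why the generic iteration-variable prefix elsewhere in the diff moves from "z" to "p", freeing "z" for the tiling dimension. A small sketch of the mapping, assuming the helpers live in torch.utils._sympy.symbol:

from torch.utils._sympy.symbol import SymT, make_symbol, prefix_str, symbol_is_type

z0 = make_symbol(SymT.ZBLOCK, 0)        # a sympy.Symbol named "z0"
assert prefix_str[SymT.ZBLOCK] == "z"
assert symbol_is_type(z0, SymT.ZBLOCK)  # so f"{z0.name[0]}mask" yields "zmask"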