mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Fix 3d tiling (#141709)
Fixes #141121 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141709 Approved by: https://github.com/eellison
This commit is contained in:
parent
90f19fee8a
commit
b2fe1b9409
|
|
@ -1495,6 +1495,33 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
|
|||
device_stats2["active.all.peak"] <= device_stats["active.all.peak"]
|
||||
)
|
||||
|
||||
@config.patch(
    {
        "triton.prefer_nd_tiling": True,
        "triton.max_tiles": 3,
    }
)
def test_3d_tiling(self):
    """Compile an add over strided 5-D views and expect a z tile in the code.

    Runs with ``triton.prefer_nd_tiling`` enabled and ``triton.max_tiles``
    raised to 3. Two inputs are built as windows into larger buffers, added
    through ``torch.compile``, and the result is compared against eager
    ``a + b``. The single generated code string must contain ``"znumel"``.
    """
    # Shape of the backing allocation and of the smaller window viewed into it.
    full_size = (5, 5, 5, 5, 5)
    view_size = (3, 3, 5, 3, 5)
    GPU_TYPE = "cuda"

    def get_input() -> torch.Tensor:
        """Return a ``view_size`` window over a fresh random ``full_size`` tensor."""
        device = torch.device(GPU_TYPE)
        full = torch.randn(full_size).to(device)
        # Keep the parent's strides: the view covers only part of each
        # dimension of the larger buffer instead of being densely packed.
        return torch.as_strided(full, view_size, full.stride())

    a, b = get_input(), get_input()

    opt_fn = torch.compile(functools.partial(torch.add))
    # The 1-tuple unpack also checks that exactly one code string came back.
    result, (code,) = run_and_get_code(opt_fn, a, b)
    self.assertEqual(result, a + b)
    self.assertIn("znumel", code)
|
||||
|
||||
def test_repeated_masked_load(self):
|
||||
target_size = (8, 2)
|
||||
mem_eff_temporal_upsampling_interp_chunks = 2
|
||||
|
|
|
|||
|
|
@ -80,8 +80,8 @@ op0.sizes = ([256], [])
|
|||
arg0_1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
|
||||
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
|
||||
class op0_loop_body:
|
||||
var_ranges = {z0: 256}
|
||||
index0 = z0
|
||||
var_ranges = {p0: 256}
|
||||
index0 = p0
|
||||
def body(self, ops):
|
||||
get_index = self.get_index('index0')
|
||||
load = ops.load('arg0_1', get_index)
|
||||
|
|
@ -107,8 +107,8 @@ op1.sizes = ([256], [])
|
|||
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
|
||||
buf1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
|
||||
class op1_loop_body:
|
||||
var_ranges = {z0: 256}
|
||||
index0 = z0
|
||||
var_ranges = {p0: 256}
|
||||
index0 = p0
|
||||
def body(self, ops):
|
||||
get_index = self.get_index('index0')
|
||||
load = ops.load('buf0', get_index)
|
||||
|
|
@ -161,8 +161,8 @@ op0.sizes = ([256], [])
|
|||
arg0_1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
|
||||
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
|
||||
class op0_loop_body:
|
||||
var_ranges = {z0: 256}
|
||||
index0 = z0
|
||||
var_ranges = {p0: 256}
|
||||
index0 = p0
|
||||
def body(self, ops):
|
||||
get_index = self.get_index('index0')
|
||||
load = ops.load('arg0_1', get_index)
|
||||
|
|
@ -187,8 +187,8 @@ op1.sizes = ([256], [])
|
|||
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
|
||||
buf1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
|
||||
class op1_loop_body:
|
||||
var_ranges = {z0: 256}
|
||||
index0 = z0
|
||||
var_ranges = {p0: 256}
|
||||
index0 = p0
|
||||
def body(self, ops):
|
||||
get_index = self.get_index('index0')
|
||||
load = ops.load('buf0', get_index)
|
||||
|
|
|
|||
|
|
@ -117,10 +117,10 @@ class ImplDetailTest(TestCase):
|
|||
snode = SchedulerNode(V.graph.scheduler, buf)
|
||||
snode.apply_new_loop_order([1, 0])
|
||||
prefix1 = self._get_snode_body_sym_prefix(snode)
|
||||
self.assertTrue(prefix1 == "z")
|
||||
self.assertTrue(prefix1 == "p")
|
||||
snode.apply_new_loop_order([1, 0])
|
||||
prefix2 = self._get_snode_body_sym_prefix(snode)
|
||||
self.assertTrue(prefix2 == "z")
|
||||
self.assertTrue(prefix2 == "p")
|
||||
|
||||
def test_reorder_and_merge_loops(self):
|
||||
sizes = (1024, 2048)
|
||||
|
|
@ -163,7 +163,7 @@ class ImplDetailTest(TestCase):
|
|||
_, body = buf.simplify_and_reorder()
|
||||
new_body = body.reorder_iter_loops([1, 2, 3, 0])
|
||||
|
||||
z0, z1, z2, z3 = (sympy_index_symbol(f"z{i}") for i in range(4))
|
||||
z0, z1, z2, z3 = (sympy_index_symbol(f"p{i}") for i in range(4))
|
||||
self.assertEqual(body.var_ranges, {z0: 128, z1: 4, z2: 49, z3: 49})
|
||||
self.assertEqual(
|
||||
body.indexing_exprs["index0"],
|
||||
|
|
|
|||
|
|
@ -155,14 +155,14 @@ class TritonSymbols:
|
|||
|
||||
block_offsets = {
|
||||
symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
|
||||
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
|
||||
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
|
||||
}
|
||||
|
||||
block_sizes = {
|
||||
symt: sympy.Symbol(
|
||||
f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True
|
||||
)
|
||||
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
|
||||
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -1564,7 +1564,7 @@ class TritonKernel(SIMDKernel):
|
|||
else:
|
||||
# var is one of xN, yN, zN or rN
|
||||
assert symbol_is_type(
|
||||
var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK)
|
||||
var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK)
|
||||
), var.name
|
||||
mask_vars.add(f"{var.name[0]}mask")
|
||||
|
||||
|
|
|
|||
|
|
@ -963,6 +963,10 @@ class triton:
|
|||
dense_indexing = False
|
||||
|
||||
# limit tiling dimensions
|
||||
# - max_tiles=1 disables tiling
|
||||
# - max_tiles=2 is the default
|
||||
# - max_tiles=3 is experimental and may have bugs
|
||||
# higher values are unsupported
|
||||
max_tiles = 2
|
||||
|
||||
# Prefer higher dimensional tilings. This simplifies indexing expressions, making
|
||||
|
|
|
|||
|
|
@ -4155,7 +4155,7 @@ class ComputedBuffer(OperationBuffer):
|
|||
(iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
|
||||
iter_ranges,
|
||||
reduce_ranges,
|
||||
prefix="z",
|
||||
prefix="p",
|
||||
)
|
||||
body = LoopBody(
|
||||
body,
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ class LoopBody:
|
|||
# use the original symbol prefix
|
||||
# Can try to optimize if this is a bottleneck for compilation time
|
||||
(iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
|
||||
iter_sizes, reduce_sizes, prefix="z"
|
||||
iter_sizes, reduce_sizes, prefix="p"
|
||||
)
|
||||
new_body2 = LoopBody(
|
||||
new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
|
||||
|
|
@ -259,7 +259,7 @@ class LoopBody:
|
|||
|
||||
# use the original symbol prefix so we can do multiple round of reordering
|
||||
(iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
|
||||
*new_sizes, prefix="z" # type: ignore[arg-type]
|
||||
*new_sizes, prefix="p" # type: ignore[arg-type]
|
||||
)
|
||||
new_body = LoopBody(
|
||||
loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class SymT(Enum):
|
|||
# Inductor: iteration domain for blockIdx.x/blockIdx.y/blockIdx.z
|
||||
XBLOCK = auto()
|
||||
YBLOCK = auto()
|
||||
ZBLOCK = auto()
|
||||
# Inductor: this is used solely for dynamic_reshape_indexer
|
||||
VIEW = auto()
|
||||
# Alternate (non-modular) indexing used in halide kernels
|
||||
|
|
@ -70,6 +71,7 @@ prefix_str = {
|
|||
SymT.TEMPLATE_INDEX: "idx",
|
||||
SymT.XBLOCK: "x",
|
||||
SymT.YBLOCK: "y",
|
||||
SymT.ZBLOCK: "z",
|
||||
SymT.INDIRECT: "indirect", # false aliasing?
|
||||
SymT.VIEW: "view",
|
||||
SymT.HALIDE: "h",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user