[inductor] skip non-trivial tiling if unbacked symints are present (#150225)

Take two of https://github.com/pytorch/pytorch/pull/149994.

This time we just skip `convert_tiling_to_3d` and `candidate_tilings` if any unbacked symints are present.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150225
Approved by: https://github.com/eellison
Colin Peppler, 2025-04-01 12:53:29 -07:00 (committed by PyTorch MergeBot)
parent 03c879d59b
commit a8f6b40e36
3 changed files with 49 additions and 6 deletions
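
For context, a minimal standalone sketch of the guard this change adds. The helper `_has_unbacked_symbols`, the symbol-name-prefix heuristic, and the toy `candidate_tilings` below are illustrative assumptions; the real code calls `free_unbacked_symbols` from `torch.fx.experimental.symbolic_shapes` inside `SIMDScheduling`:

```python
import sympy

def _has_unbacked_symbols(exprs, unbacked_prefix="u"):
    # Simplified stand-in for free_unbacked_symbols(): treat any sympy symbol
    # whose name starts with the unbacked prefix (e.g. "u0" produced by a
    # .item() call) as an unbacked symint.
    return any(
        str(sym).startswith(unbacked_prefix)
        for expr in exprs
        for sym in sympy.sympify(expr).free_symbols
    )

def candidate_tilings(pointwise_ranges, reduction_ranges):
    # Mirrors the guard added in this PR: bail out of non-trivial tiling as
    # soon as an unbacked symint appears in any iteration range, because no
    # size hint is available for it.
    if (
        len(pointwise_ranges) <= 1
        and len(reduction_ranges) <= 1
        or _has_unbacked_symbols(pointwise_ranges + reduction_ranges)
    ):
        return []
    # ... the real tiling heuristics would run here ...
    return [dict(zip("xyz", pointwise_ranges))]

u0, s0 = sympy.symbols("u0 s0", positive=True, integer=True)
print(candidate_tilings([33 * u0, 2, 64], []))  # [] -> non-trivial tiling skipped
print(candidate_tilings([33 * s0, 2, 64], []))  # one candidate tiling
```

The point is the early `return []`: without a size hint for the unbacked symbol there is no way to rank non-trivial tiling candidates, so the scheduler keeps the trivial tiling.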


@@ -931,6 +931,40 @@ class CommonTemplate:
         # Check for 3D tiling
         self.assertIn("ZBLOCK", code)
 
+    @torch._dynamo.config.patch({"capture_scalar_outputs": True})
+    @parametrize("num_tile_candidates", (1, 2))
+    def test_unbacked_size_on_non_contig_dim(self, num_tile_candidates: int):
+        # NUM_REPEAT should determine the number of candidate_tilings.
+        NUM_REPEAT = 2 if num_tile_candidates == 2 else 8
+
+        def foo(x, length):
+            unbacked = length.item()
+            torch._check_is_size(unbacked)
+            repeated = x.repeat(1, unbacked, NUM_REPEAT)
+            # permute creates a split in the middle where the unbacked symint
+            # is the first range; ranges: [33*unbacked, NUM_REPEAT, 64]
+            permute120 = repeated.permute([1, 2, 0])
+            return permute120.cos()
+
+        inps = (
+            torch.rand((64, 33, 1), device=self.device, dtype=torch.float32),
+            torch.scalar_tensor(16, device=self.device, dtype=torch.int32),
+        )
+        with torch._dynamo.config.patch({"capture_scalar_outputs": True}):
+            run_and_compare(
+                self,
+                foo,
+                *inps,
+                expected_num_triton_kernels=1,
+                expected_num_block_pointers=0,
+                config_patches={
+                    "triton.max_tiles": 3,
+                    "triton.prefer_nd_tiling": True,
+                },
+            )
+
         # block_ptr advancements should also be deferred, conditional
         # on the associated buffer not being removed
         # in this case the bernoulli operation is fused with the following sum


@@ -18,6 +18,7 @@ import sympy
 import torch
 import torch._logging
+from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
 from torch.fx.immutable_collections import immutable_dict
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
@@ -1764,7 +1765,11 @@ class SIMDScheduling(BaseScheduling):
             return tilings
 
         pointwise_ranges, reduction_ranges = node.get_ranges()
-        if len(pointwise_ranges) <= 1 and len(reduction_ranges) <= 1:
+        if (
+            len(pointwise_ranges) <= 1
+            and len(reduction_ranges) <= 1
+            or free_unbacked_symbols(pointwise_ranges + reduction_ranges)
+        ):
             return []
 
         # Tile either pointwise or reduction dims.
@@ -2013,7 +2018,11 @@
     ) -> Optional[dict[str, sympy.Expr]]:
         a0, a1 = tiling0["x"], tiling0.get("y", 1)
         b0, b1 = tiling1["x"], tiling1.get("y", 1)
-        if V.graph.sizevars.size_hint(a1 - b1) == 0:
+        if (
+            free_unbacked_symbols([a1, b1])
+            or V.graph.sizevars.size_hint(a1 - b1) == 0
+        ):
             return None
         if V.graph.sizevars.size_hint(a1 - b1) < 0:
             # swap so a0 is bigger


@@ -40,6 +40,7 @@ from torch._prims_common import (
     Number,
 )
 from torch.fx.experimental.sym_node import magic_methods, method_to_operator
+from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing
@@ -1088,8 +1089,6 @@ def trunc(x):
 @register_lowering(aten.expand, type_promotion_kind=None)
 def expand(x, sizes):
-    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
-
     (x,) = promote_constants([x])
     if isinstance(x, ir.BaseConstant):
         return ExpandView.create(x, tuple(sizes))
@@ -1166,8 +1165,9 @@ def repeat(x, repeats):
         return x_loader(index)
 
     old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size))
-    if old_size_product > 0:
-        # maybe realize the input
+    if old_size_product > 0 and not free_unbacked_symbols(new_size):
+        # maybe realize the input but skip for unbacked symints since it'll
+        # choke on the size hint.
         x.mark_reuse(
             V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product
         )
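
For intuition on the "choke on the size hint" comment: a size hint is a concrete example value for a symbolic size, and unbacked symints have none, so asking for a hint on `new_size` would fail. A minimal sketch under that assumption; the toy `size_hint` and the `u` name-prefix heuristic below are illustrative, not the real `V.graph.sizevars` API or `free_unbacked_symbols`:

```python
import sympy

def size_hint(expr, hints):
    # Illustrative stand-in for V.graph.sizevars.size_hint(): substitute the
    # known example values; if any symbol is left over (i.e. it has no hint,
    # as is the case for unbacked symints), there is nothing to return.
    val = sympy.sympify(expr).subs(hints)
    if val.free_symbols:
        raise RuntimeError(f"no size hint for {val.free_symbols}")
    return int(val)

s0, u0 = sympy.symbols("s0 u0", positive=True, integer=True)
hints = {s0: 33}                       # backed symbol: has an example value
old_size = [64, s0, 1]
new_size = [64, s0 * u0, 2]            # u0 is unbacked (e.g. from .item())

old_product = size_hint(sympy.Mul(*old_size), hints)  # fine: 64 * 33
has_unbacked = any(
    str(sym).startswith("u")
    for expr in new_size
    for sym in sympy.sympify(expr).free_symbols
)
if old_product > 0 and not has_unbacked:
    # Only safe when every symbol in new_size has a hint.
    reuse = size_hint(sympy.Mul(*new_size), hints) // old_product
else:
    reuse = None                       # skip mark_reuse, as in this PR
print(reuse)                           # None: the unbacked dim has no hint
```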