mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] skip non-trivial tiling if unbacked symints are present (#150225)
Take two of https://github.com/pytorch/pytorch/pull/149994. This time we just skip `convert_tiling_to_3d` and `candidate_tilings` if there exists unbacked symints. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150225 Approved by: https://github.com/eellison
This commit is contained in:
parent
03c879d59b
commit
a8f6b40e36
|
|
@ -931,6 +931,40 @@ class CommonTemplate:
|
|||
# Check for 3D tiling
|
||||
self.assertIn("ZBLOCK", code)
|
||||
|
||||
@torch._dynamo.config.patch({"capture_scalar_outputs": True})
|
||||
@parametrize("num_tile_candidates", (1, 2))
|
||||
def test_unbacked_size_on_non_contig_dim(self, num_tile_candidates: int):
|
||||
# NUM_REPEAT should determine # of candidate_tilings.
|
||||
NUM_REPEAT = 2 if num_tile_candidates == 2 else 8
|
||||
|
||||
def foo(x, length):
|
||||
unbacked = length.item()
|
||||
torch._check_is_size(unbacked)
|
||||
|
||||
repeated = x.repeat(1, unbacked, NUM_REPEAT)
|
||||
# permute creates split in middle with unbacked symint is the first range
|
||||
# ranges: [33*unbacked, NUM_REPEAT, 64]
|
||||
permute120 = repeated.permute([1, 2, 0])
|
||||
return permute120.cos()
|
||||
|
||||
inps = (
|
||||
torch.rand((64, 33, 1), device=self.device, dtype=torch.float32),
|
||||
torch.scalar_tensor(16, device=self.device, dtype=torch.int32),
|
||||
)
|
||||
|
||||
with torch._dynamo.config.patch({"capture_scalar_outputs": True}):
|
||||
run_and_compare(
|
||||
self,
|
||||
foo,
|
||||
*inps,
|
||||
expected_num_triton_kernels=1,
|
||||
expected_num_block_pointers=0,
|
||||
config_patches={
|
||||
"triton.max_tiles": 3,
|
||||
"triton.prefer_nd_tiling": True,
|
||||
},
|
||||
)
|
||||
|
||||
# block_ptr advancements should also be deferrered conditional
|
||||
# on the associated buffer not being removed
|
||||
# in this case the bernoulli operation is fused with the following sum
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import sympy
|
|||
|
||||
import torch
|
||||
import torch._logging
|
||||
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
||||
from torch.fx.immutable_collections import immutable_dict
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
|
||||
|
|
@ -1764,7 +1765,11 @@ class SIMDScheduling(BaseScheduling):
|
|||
return tilings
|
||||
|
||||
pointwise_ranges, reduction_ranges = node.get_ranges()
|
||||
if len(pointwise_ranges) <= 1 and len(reduction_ranges) <= 1:
|
||||
if (
|
||||
len(pointwise_ranges) <= 1
|
||||
and len(reduction_ranges) <= 1
|
||||
or free_unbacked_symbols(pointwise_ranges + reduction_ranges)
|
||||
):
|
||||
return []
|
||||
|
||||
# Tile either pointwise or reduction dims.
|
||||
|
|
@ -2013,7 +2018,11 @@ class SIMDScheduling(BaseScheduling):
|
|||
) -> Optional[dict[str, sympy.Expr]]:
|
||||
a0, a1 = tiling0["x"], tiling0.get("y", 1)
|
||||
b0, b1 = tiling1["x"], tiling1.get("y", 1)
|
||||
if V.graph.sizevars.size_hint(a1 - b1) == 0:
|
||||
|
||||
if (
|
||||
free_unbacked_symbols([a1, b1])
|
||||
or V.graph.sizevars.size_hint(a1 - b1) == 0
|
||||
):
|
||||
return None
|
||||
if V.graph.sizevars.size_hint(a1 - b1) < 0:
|
||||
# swap so a0 is bigger
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ from torch._prims_common import (
|
|||
Number,
|
||||
)
|
||||
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
|
||||
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing
|
||||
|
||||
|
|
@ -1088,8 +1089,6 @@ def trunc(x):
|
|||
|
||||
@register_lowering(aten.expand, type_promotion_kind=None)
|
||||
def expand(x, sizes):
|
||||
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
||||
|
||||
(x,) = promote_constants([x])
|
||||
if isinstance(x, ir.BaseConstant):
|
||||
return ExpandView.create(x, tuple(sizes))
|
||||
|
|
@ -1166,8 +1165,9 @@ def repeat(x, repeats):
|
|||
return x_loader(index)
|
||||
|
||||
old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size))
|
||||
if old_size_product > 0:
|
||||
# maybe realize the input
|
||||
if old_size_product > 0 and not free_unbacked_symbols(new_size):
|
||||
# maybe realize the input but skip for unbacked symints since it'll
|
||||
# choke on the size hint.
|
||||
x.mark_reuse(
|
||||
V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user