From fe285b95606e4ae60203e4b51cc18068a7336ed9 Mon Sep 17 00:00:00 2001
From: Colin Peppler
Date: Tue, 20 May 2025 14:29:15 -0700
Subject: [PATCH] [aoti] fix corner case in unbacked replacements for
 atomically_apply_size_hint (#153768)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## PR

There are a few cases that my previous PR (#153220) didn't cover.

1. The LHS/RHS ordering matters. Today, if you do `torch._check(lhs == rhs)`, it shows up as a deferred runtime assert with `Eq(lhs, rhs)`, so which side an expression lands on depends on how the check was written.
2. There can be transitive replacements, for example, expr1 -> expr2 -> u0. `test_size_with_unbacked_add_expr_transitive` tests for this.
3. An unbacked symint expr may not have a replacement that's purely a symbol; the replacement can itself be another expression. `test_size_with_unbacked_add_and_mul_expr` tests for this.

## Device assertion msg

```
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
```

## Autotuning code setup

This is the autotuning code for a concat kernel which takes input tensors (`in_buf`) and writes them to the output tensor (`out_buf`). It's important to note that the sizes of `in_buf0` and `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size for each dim except the concat dim (here that's dim=1).

```
in_buf0 = generate_example_value(size=(u1 + s0, 256)) # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10))       # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266)) # concrete size is (17900, 256+10)
triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)
```

If we look into the kernel code, you'll see that `tmp9` loads from `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.
- `tmp6` makes sure we only load when `x0` (the column index) is in the range 256 to 264.
- `xmask` makes sure we only load when `xindex` is within `xnumel`.
- `tmp6 & xmask` together essentially checks `0 ≤ x1 < u1 + s0` and `256 ≤ x0 < 264`.

The mask logic is correct. However, `in_buf1` has the shape `[8192, 10]`, which means any load where `8192 ≤ x1 < u1 + s0` will be an OOB load.

```
def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x0 = (xindex % 264)
    x1 = xindex // 264
    ...
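    # (clarifying comments, not part of the generated kernel) x1 = xindex // 264
    # is the row index of the concatenated output and x0 = xindex % 264 is the
    # column within that row; the elided tmp* lines compute tmp13, the index that
    # the device assert below checks against ks0.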
tmp6 = x0 >= tl.full([1], value=256) tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask) # device assertion is thrown here tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0") ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153768 Approved by: https://github.com/jingsh --- test/inductor/test_aot_inductor.py | 115 +++++++++++++++++++++++++++-- torch/_inductor/sizevars.py | 64 ++++++++++------ 2 files changed, 150 insertions(+), 29 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index c9d160d659f..9ab64f880b3 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1337,13 +1337,7 @@ class AOTInductorTestsTemplate: unbacked_add_expr = backed + unbacked repeated = x.repeat(unbacked_add_expr, 1) - return torch.cat( - [ - repeated, - index_select, - ], - dim=1, - ) + return torch.cat([repeated, index_select], dim=1) example_inputs = ( torch.ones(64, dtype=torch.int64, device=self.device), @@ -1365,6 +1359,113 @@ class AOTInductorTestsTemplate: } self.check_model(Repro(), example_inputs, dynamic_shapes=spec) + def test_size_with_unbacked_add_expr_transitive(self): + # Edge case with torch._check(expr1, expr2) + torch._check(expr2, unbacked). + # When generating example input sizes for autotuning, it should coalesce + # expr1, expr2, unbacked into a single size. + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + class Repro(torch.nn.Module): + def forward(self, values, repeats, mask, embeddings, x, y, z, lst): + index = torch.repeat_interleave(values, repeats) + index_select = torch.index_select(embeddings, 0, index) + + u0, u1 = lst.tolist() + torch._check_is_size(u0) + torch._check_is_size(u1) + backed0, backed1 = z.size(0), z.size(1) + + repeated0 = y.repeat(backed0 + u0, 1) + repeated1 = x.repeat(backed1 + u1, 1) + out1 = torch.empty_like(repeated1) + add_kernel[(out1.numel(),)]( + repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2 + ) + + # Implicitly add torch._check(expr2, unbacked) + cat = torch.cat([out1, index_select], dim=1) + add = repeated0 + repeated1 + + # Explicitly add torch._check(expr1, expr2) + torch._check(repeated0.size(0) == out1.size(0)) + return cat, add + + example_inputs = ( + torch.ones(64, dtype=torch.int64, device=self.device), + torch.ones(64, dtype=torch.int64, device=self.device) * 24, + torch.ones((768,), dtype=torch.int64, device=self.device).bool(), + torch.randn((401, 8), dtype=torch.bfloat16, device=self.device), + torch.randn((2, 256), dtype=torch.bfloat16, device=self.device), + torch.randn((2, 256), dtype=torch.bfloat16, device=self.device), + torch.ones(758, 758, dtype=torch.int64, device=self.device), + torch.tensor([10, 10], dtype=torch.int32, device=self.device), + ) + spec = { + "values": (Dim.DYNAMIC,), + "repeats": (Dim.DYNAMIC,), + "mask": (Dim.DYNAMIC,), + "embeddings": (Dim.DYNAMIC, Dim.STATIC), + "x": (Dim.DYNAMIC, Dim.STATIC), + "y": (Dim.DYNAMIC, Dim.STATIC), + "z": (Dim.DYNAMIC, Dim.DYNAMIC), + "lst": (Dim.STATIC,), + } + self.check_model(Repro(), example_inputs, dynamic_shapes=spec) + + @config.patch({"unbacked_symint_fallback": 128}) + def test_size_with_unbacked_add_and_mul_expr(self): + # Edge case with torch._check(add_expr, mul_expr). When generating example + # input sizes for autotuning, make sure they coalesce into a single size. 
+ if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + class Repro(torch.nn.Module): + def forward(self, values, repeats, mask, embeddings, x, y, z, lst): + u0, u1, u2 = lst.tolist() + torch._check_is_size(u0) + torch._check_is_size(u1) + torch._check_is_size(u2) + backed = z.size(0) + backed1 = z.size(1) + + unbacked_add_expr = backed + u0 + unbacked_mul_expr = backed1 + (u1 * u2) + repeated0 = x.repeat(unbacked_add_expr, 1) + repeated1 = y.repeat(unbacked_mul_expr, 1) + out0 = torch.empty_like(repeated0) + out1 = torch.empty_like(repeated1) + add_kernel[(out0.numel(),)]( + repeated0, repeated0, out0, out0.numel(), BLOCK_SIZE=2 + ) + add_kernel[(out1.numel(),)]( + repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2 + ) + + return torch.cat([out1, out0], dim=1) + + example_inputs = ( + torch.ones(64, dtype=torch.int64, device=self.device), + torch.ones(64, dtype=torch.int64, device=self.device) * 24, + torch.ones((768,), dtype=torch.int64, device=self.device).bool(), + torch.randn((401, 8), dtype=torch.bfloat16, device=self.device), + torch.randn((2, 256), dtype=torch.bfloat16, device=self.device), + torch.randn((2, 256), dtype=torch.bfloat16, device=self.device), + torch.ones(758, 758, dtype=torch.int64, device=self.device), + torch.tensor([10, 5, 2], dtype=torch.int32, device=self.device), + ) + spec = { + "values": (Dim.DYNAMIC,), + "repeats": (Dim.DYNAMIC,), + "mask": (Dim.DYNAMIC,), + "embeddings": (Dim.DYNAMIC, Dim.STATIC), + "x": (Dim.DYNAMIC, Dim.STATIC), + "y": (Dim.DYNAMIC, Dim.STATIC), + "z": (Dim.DYNAMIC, Dim.DYNAMIC), + "lst": (Dim.STATIC,), + } + self.check_model(Repro(), example_inputs, dynamic_shapes=spec) + @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet") def test_fallback_kernel_with_symexpr_output(self): if self.device != GPU_TYPE: diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 1251c4ed6b1..dac88b82cc3 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -8,7 +8,11 @@ from typing import Any, Callable, cast, Optional, Union import sympy from sympy import Expr -from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv +from torch.fx.experimental.symbolic_shapes import ( + free_unbacked_symbols, + has_free_unbacked_symbols, + ShapeEnv, +) from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.symbol import symbol_is_type, SymT @@ -62,7 +66,7 @@ class SizeVarAllocator: self.shape_env = shape_env self.var_to_val = self.shape_env.var_to_val self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements - self.unbacked_replacements: dict[Expr, Expr] = {} + self.unbacked_replacements: Optional[dict[Expr, Expr]] = None # Maps of dynamic sizes that have to be precomputed on the host to the kernel args. # The basic idea is if we have some complicated sympy expression # f(s0), we may choose to precompute it on the host and then replace @@ -639,7 +643,7 @@ class SizeVarAllocator: ) return strides - def _get_unbacked_replacements(self, expr: Expr) -> dict[Expr, Expr]: + def _get_unbacked_replacements(self) -> dict[Expr, Expr]: """ This helps with covering unbacked symint cases where you may have two expressions: s0 + u0 and u1. And s0 + u0 is known to be equal to u1 @@ -649,33 +653,49 @@ class SizeVarAllocator: hint for both s0 + u0 and u1, but it first needs to know they are equal. Then it can substitute s0 + u0 for u1. 
""" - if expr in self.unbacked_replacements: - return self.unbacked_replacements[expr] + if self.unbacked_replacements is not None: + return self.unbacked_replacements - runtime_asserts = itertools.chain.from_iterable( - self.shape_env.deferred_runtime_asserts.get(u, []) - for u in free_unbacked_symbols(expr) - ) - equalities = ( - assertion.expr - for assertion in runtime_asserts - if isinstance(assertion.expr, sympy.Equality) - ) - replacements = {eq.rhs: eq.lhs for eq in equalities} + self.unbacked_replacements = {} + for assertions in self.shape_env.deferred_runtime_asserts.values(): + for assertion in assertions: + if not isinstance(assertion.expr, sympy.Equality): + continue - self.unbacked_replacements[expr] = replacements - return replacements + lhs, rhs = assertion.expr.lhs, assertion.expr.rhs + l2r = lhs.compare(rhs) == 1 # see sympy.Basic.compare + src = lhs if l2r else rhs + dst = rhs if l2r else lhs + + existing_replacement = self.unbacked_replacements.get(src, None) + if existing_replacement and isinstance( + existing_replacement, sympy.Symbol + ): + # Prefer to keep replacements with symbols. + continue + self.unbacked_replacements[src] = dst + return self.unbacked_replacements + + @functools.lru_cache # noqa: B019 + def _sub_unbacked_exprs(self, expr: Expr) -> Expr: + # it's fine to cache this fn since self is a singleton + replacements = self._get_unbacked_replacements() + while True: + new_expr = expr.subs(replacements) + if new_expr == expr: + return new_expr + expr = sympy.factor(new_expr) def atomically_apply_size_hint( self, expr: Union[Expr, int], *, fallback: Optional[int] = None ) -> Union[Expr, int]: - if isinstance(expr, int): + if isinstance(expr, (int, sympy.Integer)): return int(expr) - # Make sure to substitute with the factored version - # e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0 - unbacked_replacements = self._get_unbacked_replacements(expr) - expr = sympy.factor(expr).subs(unbacked_replacements) + if has_free_unbacked_symbols(expr): + # Make sure to substitute with the factored version + # e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0 + expr = self._sub_unbacked_exprs(sympy.factor(expr)) # For multiple expressions that depend on an unbacked symint, # we want to compute them consistently for a size hint we have chosen.