[aoti] fix corner case in unbacked replacements for atomically_apply_size_hint (#153768)

## PR
There are a few cases that my previous PR (#153220) didn't cover.
1. The LHS/RHS ordering matters. Today, if you do `torch._check(lhs == rhs)`, it will show up as a deferred runtime assert with `Eq(lhs, rhs)`, so the replacement direction cannot assume which side the unbacked expression lands on.
2. There can be transitive replacements, e.g. `expr1 -> expr2 -> u0`. `test_size_with_unbacked_add_expr_transitive` tests for this.
3. An unbacked symint expr may not have a replacement that is purely a symbol; it could be another expression instead. `test_size_with_unbacked_add_and_mul_expr` tests for this. (A minimal sketch of the substitution idea follows this list.)
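
To make these cases concrete, here is a minimal, illustrative sympy sketch (not the actual PR code; the symbols, the equalities, and `sub_to_fixed_point` are made up for illustration) of turning recorded equalities into replacements and substituting until a fixed point, so chained replacements (case 2) and non-symbol replacements (case 3) collapse into one canonical expression:

```
import sympy

s0, u0, u1, u2 = sympy.symbols("s0 u0 u1 u2", integer=True, positive=True)

# Hypothetical equalities, as they might be recorded by deferred runtime asserts:
# Eq(s0 + u0, u1) and Eq(u1, u0*u2). Case 1: each equality must be oriented into
# a single source -> target replacement, regardless of which side it was written on.
replacements = {
    s0 + u0: u1,   # expression -> unbacked symbol
    u1: u0 * u2,   # symbol -> another expression (case 3)
}

def sub_to_fixed_point(expr: sympy.Expr) -> sympy.Expr:
    """Substitute repeatedly until the expression stops changing (case 2)."""
    while True:
        # factor() regroups e.g. 10*s0 + 10*u0 into 10*(s0 + u0) so the
        # recorded Add can actually be matched by subs().
        new_expr = sympy.factor(expr.subs(replacements))
        if new_expr == expr:
            return new_expr
        expr = new_expr

print(sub_to_fixed_point(10 * (s0 + u0)))  # 10*u0*u2
print(sub_to_fixed_point(10 * u1))         # 10*u0*u2 -- same canonical size expression
```

In the actual fix (see the `sizevars.py` diff below), the replacement map is built from `shape_env.deferred_runtime_asserts`, the direction is chosen with `sympy.Basic.compare`, and an existing replacement whose target is already a plain symbol is kept in preference to a new one.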

## Device assertion msg

```
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
```

## Autotuning code setup
This is the autotuning code for a concat kernel, which takes input tensors (`in_buf0`, `in_buf1`, ...) and writes them to the output tensor (`out_buf`).

It's important to note that the sizes of `in_buf0` and `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size along every dim except the concat dim (here, that's dim=1).
```
in_buf0 = generate_example_value(size=(u1 + s0, 256))   # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10))         # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266))   # concrete size is (17900, 256+10)
triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)
```

If you look into the kernel code, you'll see that `tmp9` loads from `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.
- `tmp6` makes sure we're only loading when the column index `x0` is between 256 and 264.
- `xmask` makes sure we're only loading when `xindex` is within `xnumel`.
- Together, `tmp6 & xmask` is essentially checking `0 ≤ x1 < u1 + s0` and `256 ≤ x0 < 264`.

The mask logic is correct. However, `in_buf1` has shape `[8192, 10]`, which means any load where `8192 ≤ x1 < u1 + s0` is an OOB load.
```
def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x0 = (xindex % 264)  # column index in the concatenated output
    x1 = xindex // 264   # row index
    ...
    tmp6 = x0 >= tl.full([1], 256, tl.int64)  # columns >= 256 are filled from in_buf1
    tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask)
    # device assertion is thrown here
    tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0")
```
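
For completeness, here is a rough sketch of the intended outcome on this example (hypothetical: `size_hint`, the replacement orientation, and the concrete values 9708/8192 are assumptions back-solved from the sizes above, with 8192 being the default unbacked fallback). Once the recorded equality between `u1 + s0` and `u0` is applied while computing size hints, both expressions resolve to the same value, so `in_buf1` would be generated with a matching dim=0:

```
import sympy

s0, u0, u1 = sympy.symbols("s0 u0 u1", integer=True, positive=True)

replacements = {u0: u1 + s0}      # hypothetical orientation of the recorded equality
hints = {s0: 9708, u1: 8192}      # assumed backed hint and unbacked fallback

def size_hint(expr: sympy.Expr) -> int:
    # Substitute the equality first, then plug in concrete hints.
    return int(expr.subs(replacements).subs(hints))

print(size_hint(u1 + s0))  # 17900 -> rows of in_buf0 / out_buf
print(size_hint(u0))       # 17900 -> in_buf1 rows now agree, so no OOB loads
```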

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153768
Approved by: https://github.com/jingsh
Colin Peppler 2025-05-20 14:29:15 -07:00 committed by PyTorch MergeBot
parent a264af8c71
commit fe285b9560
2 changed files with 150 additions and 29 deletions

View File

@@ -1337,13 +1337,7 @@ class AOTInductorTestsTemplate:
                 unbacked_add_expr = backed + unbacked
 
                 repeated = x.repeat(unbacked_add_expr, 1)
-                return torch.cat(
-                    [
-                        repeated,
-                        index_select,
-                    ],
-                    dim=1,
-                )
+                return torch.cat([repeated, index_select], dim=1)
 
         example_inputs = (
             torch.ones(64, dtype=torch.int64, device=self.device),
@@ -1365,6 +1359,113 @@
         }
         self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
 
+    def test_size_with_unbacked_add_expr_transitive(self):
+        # Edge case with torch._check(expr1, expr2) + torch._check(expr2, unbacked).
+        # When generating example input sizes for autotuning, it should coalesce
+        # expr1, expr2, unbacked into a single size.
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Repro(torch.nn.Module):
+            def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
+                index = torch.repeat_interleave(values, repeats)
+                index_select = torch.index_select(embeddings, 0, index)
+                u0, u1 = lst.tolist()
+                torch._check_is_size(u0)
+                torch._check_is_size(u1)
+                backed0, backed1 = z.size(0), z.size(1)
+                repeated0 = y.repeat(backed0 + u0, 1)
+                repeated1 = x.repeat(backed1 + u1, 1)
+                out1 = torch.empty_like(repeated1)
+                add_kernel[(out1.numel(),)](
+                    repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2
+                )
+                # Implicitly add torch._check(expr2, unbacked)
+                cat = torch.cat([out1, index_select], dim=1)
+                add = repeated0 + repeated1
+                # Explicitly add torch._check(expr1, expr2)
+                torch._check(repeated0.size(0) == out1.size(0))
+                return cat, add
+
+        example_inputs = (
+            torch.ones(64, dtype=torch.int64, device=self.device),
+            torch.ones(64, dtype=torch.int64, device=self.device) * 24,
+            torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
+            torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.ones(758, 758, dtype=torch.int64, device=self.device),
+            torch.tensor([10, 10], dtype=torch.int32, device=self.device),
+        )
+        spec = {
+            "values": (Dim.DYNAMIC,),
+            "repeats": (Dim.DYNAMIC,),
+            "mask": (Dim.DYNAMIC,),
+            "embeddings": (Dim.DYNAMIC, Dim.STATIC),
+            "x": (Dim.DYNAMIC, Dim.STATIC),
+            "y": (Dim.DYNAMIC, Dim.STATIC),
+            "z": (Dim.DYNAMIC, Dim.DYNAMIC),
+            "lst": (Dim.STATIC,),
+        }
+        self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
+
+    @config.patch({"unbacked_symint_fallback": 128})
+    def test_size_with_unbacked_add_and_mul_expr(self):
+        # Edge case with torch._check(add_expr, mul_expr). When generating example
+        # input sizes for autotuning, make sure they coalesce into a single size.
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Repro(torch.nn.Module):
+            def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
+                u0, u1, u2 = lst.tolist()
+                torch._check_is_size(u0)
+                torch._check_is_size(u1)
+                torch._check_is_size(u2)
+                backed = z.size(0)
+                backed1 = z.size(1)
+                unbacked_add_expr = backed + u0
+                unbacked_mul_expr = backed1 + (u1 * u2)
+                repeated0 = x.repeat(unbacked_add_expr, 1)
+                repeated1 = y.repeat(unbacked_mul_expr, 1)
+                out0 = torch.empty_like(repeated0)
+                out1 = torch.empty_like(repeated1)
+                add_kernel[(out0.numel(),)](
+                    repeated0, repeated0, out0, out0.numel(), BLOCK_SIZE=2
+                )
+                add_kernel[(out1.numel(),)](
+                    repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2
+                )
+                return torch.cat([out1, out0], dim=1)
+
+        example_inputs = (
+            torch.ones(64, dtype=torch.int64, device=self.device),
+            torch.ones(64, dtype=torch.int64, device=self.device) * 24,
+            torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
+            torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.ones(758, 758, dtype=torch.int64, device=self.device),
+            torch.tensor([10, 5, 2], dtype=torch.int32, device=self.device),
+        )
+        spec = {
+            "values": (Dim.DYNAMIC,),
+            "repeats": (Dim.DYNAMIC,),
+            "mask": (Dim.DYNAMIC,),
+            "embeddings": (Dim.DYNAMIC, Dim.STATIC),
+            "x": (Dim.DYNAMIC, Dim.STATIC),
+            "y": (Dim.DYNAMIC, Dim.STATIC),
+            "z": (Dim.DYNAMIC, Dim.DYNAMIC),
+            "lst": (Dim.STATIC,),
+        }
+        self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
+
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:

View File

@@ -8,7 +8,11 @@ from typing import Any, Callable, cast, Optional, Union
 import sympy
 from sympy import Expr
 
-from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv
+from torch.fx.experimental.symbolic_shapes import (
+    free_unbacked_symbols,
+    has_free_unbacked_symbols,
+    ShapeEnv,
+)
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import symbol_is_type, SymT
@@ -62,7 +66,7 @@ class SizeVarAllocator:
         self.shape_env = shape_env
         self.var_to_val = self.shape_env.var_to_val
         self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
-        self.unbacked_replacements: dict[Expr, Expr] = {}
+        self.unbacked_replacements: Optional[dict[Expr, Expr]] = None
         # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
         # The basic idea is if we have some complicated sympy expression
         # f(s0), we may choose to precompute it on the host and then replace
@@ -639,7 +643,7 @@ class SizeVarAllocator:
         )
         return strides
 
-    def _get_unbacked_replacements(self, expr: Expr) -> dict[Expr, Expr]:
+    def _get_unbacked_replacements(self) -> dict[Expr, Expr]:
         """
         This helps with covering unbacked symint cases where you may have two
         expressions: s0 + u0 and u1. And s0 + u0 is known to be equal to u1
@@ -649,33 +653,49 @@
         hint for both s0 + u0 and u1, but it first needs to know they are equal.
         Then it can substitute s0 + u0 for u1.
         """
-        if expr in self.unbacked_replacements:
-            return self.unbacked_replacements[expr]
+        if self.unbacked_replacements is not None:
+            return self.unbacked_replacements
 
-        runtime_asserts = itertools.chain.from_iterable(
-            self.shape_env.deferred_runtime_asserts.get(u, [])
-            for u in free_unbacked_symbols(expr)
-        )
-        equalities = (
-            assertion.expr
-            for assertion in runtime_asserts
-            if isinstance(assertion.expr, sympy.Equality)
-        )
-        replacements = {eq.rhs: eq.lhs for eq in equalities}
+        self.unbacked_replacements = {}
+        for assertions in self.shape_env.deferred_runtime_asserts.values():
+            for assertion in assertions:
+                if not isinstance(assertion.expr, sympy.Equality):
+                    continue
-        self.unbacked_replacements[expr] = replacements
-        return replacements
+                lhs, rhs = assertion.expr.lhs, assertion.expr.rhs
+                l2r = lhs.compare(rhs) == 1  # see sympy.Basic.compare
+                src = lhs if l2r else rhs
+                dst = rhs if l2r else lhs
+
+                existing_replacement = self.unbacked_replacements.get(src, None)
+                if existing_replacement and isinstance(
+                    existing_replacement, sympy.Symbol
+                ):
+                    # Prefer to keep replacements with symbols.
+                    continue
+
+                self.unbacked_replacements[src] = dst
+        return self.unbacked_replacements
+
+    @functools.lru_cache  # noqa: B019
+    def _sub_unbacked_exprs(self, expr: Expr) -> Expr:
+        # it's fine to cache this fn since self is a singleton
+        replacements = self._get_unbacked_replacements()
+        while True:
+            new_expr = expr.subs(replacements)
+            if new_expr == expr:
+                return new_expr
+            expr = sympy.factor(new_expr)
 
     def atomically_apply_size_hint(
         self, expr: Union[Expr, int], *, fallback: Optional[int] = None
     ) -> Union[Expr, int]:
-        if isinstance(expr, int):
+        if isinstance(expr, (int, sympy.Integer)):
             return int(expr)
 
-        # Make sure to substitute with the factored version
-        # e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0
-        unbacked_replacements = self._get_unbacked_replacements(expr)
-        expr = sympy.factor(expr).subs(unbacked_replacements)
+        if has_free_unbacked_symbols(expr):
+            # Make sure to substitute with the factored version
+            # e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0
+            expr = self._sub_unbacked_exprs(sympy.factor(expr))
 
         # For multiple expressions that depend on an unbacked symint,
         # we want to compute them consistently for a size hint we have chosen.