mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[aoti] fix corner case in unbacked replacements for atomically_apply_size_hint (#153768)
## PR There are a few cases that my previous PR (#153220) didn't cover. 1. The LHS/RHS matters. Today, if you do `torch._check(lhs == rhs)` then it will show up as a deferred runtime assert with `Eq(lhs, rhs)`. 2. There can be transitive replacements. For example, expr1 -> expr2 -> u0. `test_size_with_unbacked_add_expr_transitive` tests for this. 3. An unbacked symint expr may not have a replacement that's purely a symbol, for instance, it could be another expression. `test_size_with_unbacked_add_and_mul_expr` tests for this. ## Device assertion msg ``` /tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed. ... /tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed. ``` ## Autotuning code setup This is the autotuning code for a concat kernel which takes input tensors (`in_buf`) and writes them to the (`out_buf`). It's important to note the size of `in_buf0` is the same as `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size for each dim except for the concat dim (here that's dim=1). ``` in_buf0 = generate_example_value(size=(u1 + s0, 256)) # concrete size is (17900, 256) in_buf1 = generate_example_value(size=(u0, 10)) # concrete size is (8192, 10) ... out_buf = generate_example_value(size=(u1 + s0, 266)) # concrete size is (17900, 256+10) triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...) ``` If we look into the kernel code, you'll see that `tmp9` loads `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads. - `tmp6` makes sure we're only loading with the `xindex` from 256 to 264. - `xmask` makes sure we're only loading with the `xindex` within `xnumel`. - `tmp6 & xmask` together is essentially checking `0 ≤ x0 < u1 + s0` and `256 ≤ x1 < 264`. The mask logic is correct, however, `in_buf1` has the shape `[8192, 10]` this means any load where `8192 ≤ x0 < u1 + s0` will be an OOB load. ``` def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK): xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK) xmask = xindex < xnumel x0 = (xindex % 264) x1 = xindex // 264 ... tmp6 = x0 >= tl.full([1], value=256) tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask) # device assertion is thrown here tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0") ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153768 Approved by: https://github.com/jingsh
This commit is contained in:
parent
a264af8c71
commit
fe285b9560
|
|
@ -1337,13 +1337,7 @@ class AOTInductorTestsTemplate:
|
|||
|
||||
unbacked_add_expr = backed + unbacked
|
||||
repeated = x.repeat(unbacked_add_expr, 1)
|
||||
return torch.cat(
|
||||
[
|
||||
repeated,
|
||||
index_select,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
return torch.cat([repeated, index_select], dim=1)
|
||||
|
||||
example_inputs = (
|
||||
torch.ones(64, dtype=torch.int64, device=self.device),
|
||||
|
|
@ -1365,6 +1359,113 @@ class AOTInductorTestsTemplate:
|
|||
}
|
||||
self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
|
||||
|
||||
def test_size_with_unbacked_add_expr_transitive(self):
|
||||
# Edge case with torch._check(expr1, expr2) + torch._check(expr2, unbacked).
|
||||
# When generating example input sizes for autotuning, it should coalesce
|
||||
# expr1, expr2, unbacked into a single size.
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Repro(torch.nn.Module):
|
||||
def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
|
||||
index = torch.repeat_interleave(values, repeats)
|
||||
index_select = torch.index_select(embeddings, 0, index)
|
||||
|
||||
u0, u1 = lst.tolist()
|
||||
torch._check_is_size(u0)
|
||||
torch._check_is_size(u1)
|
||||
backed0, backed1 = z.size(0), z.size(1)
|
||||
|
||||
repeated0 = y.repeat(backed0 + u0, 1)
|
||||
repeated1 = x.repeat(backed1 + u1, 1)
|
||||
out1 = torch.empty_like(repeated1)
|
||||
add_kernel[(out1.numel(),)](
|
||||
repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2
|
||||
)
|
||||
|
||||
# Implicitly add torch._check(expr2, unbacked)
|
||||
cat = torch.cat([out1, index_select], dim=1)
|
||||
add = repeated0 + repeated1
|
||||
|
||||
# Explicitly add torch._check(expr1, expr2)
|
||||
torch._check(repeated0.size(0) == out1.size(0))
|
||||
return cat, add
|
||||
|
||||
example_inputs = (
|
||||
torch.ones(64, dtype=torch.int64, device=self.device),
|
||||
torch.ones(64, dtype=torch.int64, device=self.device) * 24,
|
||||
torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
|
||||
torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
|
||||
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
|
||||
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
|
||||
torch.ones(758, 758, dtype=torch.int64, device=self.device),
|
||||
torch.tensor([10, 10], dtype=torch.int32, device=self.device),
|
||||
)
|
||||
spec = {
|
||||
"values": (Dim.DYNAMIC,),
|
||||
"repeats": (Dim.DYNAMIC,),
|
||||
"mask": (Dim.DYNAMIC,),
|
||||
"embeddings": (Dim.DYNAMIC, Dim.STATIC),
|
||||
"x": (Dim.DYNAMIC, Dim.STATIC),
|
||||
"y": (Dim.DYNAMIC, Dim.STATIC),
|
||||
"z": (Dim.DYNAMIC, Dim.DYNAMIC),
|
||||
"lst": (Dim.STATIC,),
|
||||
}
|
||||
self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
|
||||
|
||||
@config.patch({"unbacked_symint_fallback": 128})
|
||||
def test_size_with_unbacked_add_and_mul_expr(self):
|
||||
# Edge case with torch._check(add_expr, mul_expr). When generating example
|
||||
# input sizes for autotuning, make sure they coalesce into a single size.
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Repro(torch.nn.Module):
|
||||
def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
|
||||
u0, u1, u2 = lst.tolist()
|
||||
torch._check_is_size(u0)
|
||||
torch._check_is_size(u1)
|
||||
torch._check_is_size(u2)
|
||||
backed = z.size(0)
|
||||
backed1 = z.size(1)
|
||||
|
||||
unbacked_add_expr = backed + u0
|
||||
unbacked_mul_expr = backed1 + (u1 * u2)
|
||||
repeated0 = x.repeat(unbacked_add_expr, 1)
|
||||
repeated1 = y.repeat(unbacked_mul_expr, 1)
|
||||
out0 = torch.empty_like(repeated0)
|
||||
out1 = torch.empty_like(repeated1)
|
||||
add_kernel[(out0.numel(),)](
|
||||
repeated0, repeated0, out0, out0.numel(), BLOCK_SIZE=2
|
||||
)
|
||||
add_kernel[(out1.numel(),)](
|
||||
repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2
|
||||
)
|
||||
|
||||
return torch.cat([out1, out0], dim=1)
|
||||
|
||||
example_inputs = (
|
||||
torch.ones(64, dtype=torch.int64, device=self.device),
|
||||
torch.ones(64, dtype=torch.int64, device=self.device) * 24,
|
||||
torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
|
||||
torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
|
||||
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
|
||||
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
|
||||
torch.ones(758, 758, dtype=torch.int64, device=self.device),
|
||||
torch.tensor([10, 5, 2], dtype=torch.int32, device=self.device),
|
||||
)
|
||||
spec = {
|
||||
"values": (Dim.DYNAMIC,),
|
||||
"repeats": (Dim.DYNAMIC,),
|
||||
"mask": (Dim.DYNAMIC,),
|
||||
"embeddings": (Dim.DYNAMIC, Dim.STATIC),
|
||||
"x": (Dim.DYNAMIC, Dim.STATIC),
|
||||
"y": (Dim.DYNAMIC, Dim.STATIC),
|
||||
"z": (Dim.DYNAMIC, Dim.DYNAMIC),
|
||||
"lst": (Dim.STATIC,),
|
||||
}
|
||||
self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
|
||||
|
||||
@skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
|
||||
def test_fallback_kernel_with_symexpr_output(self):
|
||||
if self.device != GPU_TYPE:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,11 @@ from typing import Any, Callable, cast, Optional, Union
|
|||
import sympy
|
||||
from sympy import Expr
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
free_unbacked_symbols,
|
||||
has_free_unbacked_symbols,
|
||||
ShapeEnv,
|
||||
)
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
|
|
@ -62,7 +66,7 @@ class SizeVarAllocator:
|
|||
self.shape_env = shape_env
|
||||
self.var_to_val = self.shape_env.var_to_val
|
||||
self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
|
||||
self.unbacked_replacements: dict[Expr, Expr] = {}
|
||||
self.unbacked_replacements: Optional[dict[Expr, Expr]] = None
|
||||
# Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
|
||||
# The basic idea is if we have some complicated sympy expression
|
||||
# f(s0), we may choose to precompute it on the host and then replace
|
||||
|
|
@ -639,7 +643,7 @@ class SizeVarAllocator:
|
|||
)
|
||||
return strides
|
||||
|
||||
def _get_unbacked_replacements(self, expr: Expr) -> dict[Expr, Expr]:
|
||||
def _get_unbacked_replacements(self) -> dict[Expr, Expr]:
|
||||
"""
|
||||
This helps with covering unbacked symint cases where you may have two
|
||||
expressions: s0 + u0 and u1. And s0 + u0 is known to be equal to u1
|
||||
|
|
@ -649,33 +653,49 @@ class SizeVarAllocator:
|
|||
hint for both s0 + u0 and u1, but it first needs to know they are equal.
|
||||
Then it can substitute s0 + u0 for u1.
|
||||
"""
|
||||
if expr in self.unbacked_replacements:
|
||||
return self.unbacked_replacements[expr]
|
||||
if self.unbacked_replacements is not None:
|
||||
return self.unbacked_replacements
|
||||
|
||||
runtime_asserts = itertools.chain.from_iterable(
|
||||
self.shape_env.deferred_runtime_asserts.get(u, [])
|
||||
for u in free_unbacked_symbols(expr)
|
||||
)
|
||||
equalities = (
|
||||
assertion.expr
|
||||
for assertion in runtime_asserts
|
||||
if isinstance(assertion.expr, sympy.Equality)
|
||||
)
|
||||
replacements = {eq.rhs: eq.lhs for eq in equalities}
|
||||
self.unbacked_replacements = {}
|
||||
for assertions in self.shape_env.deferred_runtime_asserts.values():
|
||||
for assertion in assertions:
|
||||
if not isinstance(assertion.expr, sympy.Equality):
|
||||
continue
|
||||
|
||||
self.unbacked_replacements[expr] = replacements
|
||||
return replacements
|
||||
lhs, rhs = assertion.expr.lhs, assertion.expr.rhs
|
||||
l2r = lhs.compare(rhs) == 1 # see sympy.Basic.compare
|
||||
src = lhs if l2r else rhs
|
||||
dst = rhs if l2r else lhs
|
||||
|
||||
existing_replacement = self.unbacked_replacements.get(src, None)
|
||||
if existing_replacement and isinstance(
|
||||
existing_replacement, sympy.Symbol
|
||||
):
|
||||
# Prefer to keep replacements with symbols.
|
||||
continue
|
||||
self.unbacked_replacements[src] = dst
|
||||
return self.unbacked_replacements
|
||||
|
||||
@functools.lru_cache # noqa: B019
|
||||
def _sub_unbacked_exprs(self, expr: Expr) -> Expr:
|
||||
# it's fine to cache this fn since self is a singleton
|
||||
replacements = self._get_unbacked_replacements()
|
||||
while True:
|
||||
new_expr = expr.subs(replacements)
|
||||
if new_expr == expr:
|
||||
return new_expr
|
||||
expr = sympy.factor(new_expr)
|
||||
|
||||
def atomically_apply_size_hint(
|
||||
self, expr: Union[Expr, int], *, fallback: Optional[int] = None
|
||||
) -> Union[Expr, int]:
|
||||
if isinstance(expr, int):
|
||||
if isinstance(expr, (int, sympy.Integer)):
|
||||
return int(expr)
|
||||
|
||||
if has_free_unbacked_symbols(expr):
|
||||
# Make sure to substitute with the factored version
|
||||
# e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0
|
||||
unbacked_replacements = self._get_unbacked_replacements(expr)
|
||||
expr = sympy.factor(expr).subs(unbacked_replacements)
|
||||
expr = self._sub_unbacked_exprs(sympy.factor(expr))
|
||||
|
||||
# For multiple expressions that depend on an unbacked symint,
|
||||
# we want to compute them consistently for a size hint we have chosen.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user