Unify dynamic shapes APIs naming 2 (expect_true and check) attempt2 (#156518)

Summary:
The functions guard_lt, guard_equals, and guard_leq work similarly to torch.check and expect_true, but they operate on SymPy expressions. Notably, guard_equals applies local replacements before comparison, which might be better extracted into a separate function.

This pull request standardizes naming conventions to match symbolic_shapes.py. Specifically,
-  it introduces size_vars.expect_true and size_vars.check.
- guard_lt becomes check_lt
- guard_leq becomes check_leq
- guard_equals becomes check_equals

I am also seeing a couple of wrong usages !! that i will fix  in the next PR

Test Plan:
OSS and cont

Rollback Plan:

Differential Revision: D77054177

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156518
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Laith Sakka 2025-06-24 21:01:38 +00:00 committed by PyTorch MergeBot
parent dfef1e4408
commit 26f7ca3972
11 changed files with 80 additions and 86 deletions

View File

@ -3735,7 +3735,7 @@ class TilingSelect:
call_ranges[tiling_indice], fallback=0
)
if call_range < factor_lowp:
V.graph.sizevars.guard_lt(call_range, factor_lowp) # type: ignore[arg-type]
V.graph.sizevars.check_lt(call_range, factor_lowp) # type: ignore[arg-type]
tiling_factor = factor_lowp // 2
break
elif call_ranges[tiling_indice] < factor_lowp:

View File

@ -648,7 +648,7 @@ def eq(left, right):
except TypeError: # unbacked symints
return False
if a == b:
V.graph.sizevars.guard_equals(left, right)
V.graph.sizevars.check_equals(left, right)
return a == b
@ -664,7 +664,7 @@ def lt(left, right):
return left != right
return False
if a < b:
V.graph.sizevars.guard_lt(left, right)
V.graph.sizevars.check_lt(left, right)
return a < b

View File

@ -1397,9 +1397,9 @@ class SIMDScheduling(BaseScheduling):
# Only install guards for 32-bit indexing as there is no correctness
# issue with using 64-bit for everything
V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type]
V.graph.sizevars.check_leq(numel, int_max) # type: ignore[arg-type]
for size in buf_sizes:
V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type]
V.graph.sizevars.check_leq(size, int_max) # type: ignore[arg-type]
return True
def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):

View File

@ -81,7 +81,7 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
# no hint: we'll see if we know that this is a 32-bit int, and guard if possible.
int_max = torch.iinfo(torch.int32).max
if expr_fits_within_32bit(arg.expr):
V.graph.sizevars.guard_leq(arg.expr, int_max)
V.graph.sizevars.check_leq(arg.expr, int_max)
return "i32"
else:
return "i64"

View File

@ -3034,7 +3034,7 @@ class View(GenericView):
new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
break
V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size))
V.graph.sizevars.check_equals(sympy_product(old_size), sympy_product(new_size))
return old_size, new_size
@classmethod
@ -3091,14 +3091,14 @@ class View(GenericView):
stack_old.append(size_old) # re-add
elif size_hint(size_new) == size_hint(size_old):
view_expr.append(var)
V.graph.sizevars.guard_equals(size_new, size_old)
V.graph.sizevars.check_equals(size_new, size_old)
elif size_hint(size_new) < size_hint(size_old):
while size_hint(size_new) < size_hint(size_old):
var2, size_new2 = stack_new.pop()
var = var2 * size_new + var
size_new = size_new * size_new2
view_expr.append(var)
V.graph.sizevars.guard_equals(size_new, size_old)
V.graph.sizevars.check_equals(size_new, size_old)
elif size_hint(size_new) > size_hint(size_old):
divisor = sympy.S.One
modulus = size_old
@ -3109,18 +3109,18 @@ class View(GenericView):
view_expr.append(ModularIndexing(var, divisor, modulus))
divisor = divisor * modulus
size_old = size_old * modulus
V.graph.sizevars.guard_equals(size_new, size_old)
V.graph.sizevars.check_equals(size_new, size_old)
else:
raise AssertionError
while stack_old:
size_old = stack_old.pop()
V.graph.sizevars.guard_equals(size_old, 1)
V.graph.sizevars.check_equals(size_old, 1)
view_expr.append(sympy.S.Zero)
while stack_new:
var, size_new = stack_new.pop()
V.graph.sizevars.guard_equals(size_new, 1)
V.graph.sizevars.check_equals(size_new, 1)
if dense_dim is not None and len(new_size) == 1:
view_expr.reverse()
@ -3958,7 +3958,7 @@ class MutationLayoutSHOULDREMOVE(Layout):
dtype=src.get_dtype(),
inner_fn=src.make_loader(),
ranges=[
V.graph.sizevars.guard_equals(a, b)
V.graph.sizevars.check_equals_and_simplify(a, b)
for a, b in zip(src.get_size(), dst.get_size())
],
).data
@ -4948,7 +4948,7 @@ class ConcatKernel(NopKernel):
if j == dim:
new_size[j] = new_size[j] + input_size[j]
else:
new_size[j] = V.graph.sizevars.guard_equals(
new_size[j] = V.graph.sizevars.check_equals_and_simplify(
new_size[j], input_size[j]
)
offsets_end.append(new_size[dim])
@ -5085,7 +5085,7 @@ class ConcatKernel(NopKernel):
dtype=src.get_dtype(),
inner_fn=src.make_loader(),
ranges=[
V.graph.sizevars.guard_equals(a, b)
V.graph.sizevars.check_equals_and_simplify(a, b)
for a, b in zip(src.get_size(), dst.get_size())
],
)
@ -8051,7 +8051,7 @@ class WhileLoop(ExternKernel):
rhs_exprs: Sequence[Union[int, Any]],
) -> None:
for lhs, rhs in zip(lhs_exprs, rhs_exprs):
V.graph.sizevars.guard_equals(lhs, rhs)
V.graph.sizevars.check_equals(lhs, rhs)
_guard_list_equals(op.get_size(), bo.get_size())
_guard_list_equals(op.get_stride(), bo.get_stride())

View File

@ -483,7 +483,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
)
query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride)
V.graph.sizevars.guard_leq(
V.graph.sizevars.check_leq(
seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"])
)

View File

@ -1058,8 +1058,8 @@ def tuned_sparse_semi_structured_mm(
m1, k1 = mat1.get_size()
m2, _ = mat1_meta.get_size()
k2, n = mat2.get_size()
m = V.graph.sizevars.guard_equals(m1, m2)
k = V.graph.sizevars.guard_equals(2 * k1, k2)
m = V.graph.sizevars.check_equals_and_simplify(m1, m2)
k = V.graph.sizevars.check_equals_and_simplify(2 * k1, k2)
if layout is None:
from torch._inductor.ir import FixedLayout

View File

@ -157,10 +157,10 @@ def mm_args(
*b2, n, k2 = mat2.get_size()
else:
*b2, k2, n = mat2.get_size()
b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)]
if use_4x2_dim:
k2 = k2 * 2
k = V.graph.sizevars.guard_equals(k1, k2)
k = V.graph.sizevars.check_equals_and_simplify(k1, k2)
if layout is None:
from torch._inductor.ir import FixedLayout

View File

@ -611,28 +611,28 @@ def _tuned_grouped_mm_common(
m, k1 = m1_size
k2, _ = m2_size
g = offs.get_size()[0]
V.graph.sizevars.guard_equals(k1, k2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.guard_equals(g1, g2)
V.graph.sizevars.guard_equals(k1, k2)
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.guard_equals(g1, g2)
V.graph.sizevars.guard_equals(k1, k2)
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.guard_equals(g1, g2)
V.graph.sizevars.guard_equals(k1, k2)
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor")

View File

@ -494,7 +494,7 @@ def broadcast_symbolic_shapes(a, b):
):
output.append(y)
else:
V.graph.sizevars.guard_equals(x, y)
V.graph.sizevars.check_equals(x, y)
if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols):
output.append(y) # prefer shorter formula
else:
@ -1813,8 +1813,8 @@ def unfold(x, dimension, size, step):
dim_size = sizes[dim]
sizevars = V.graph.sizevars
sizevars.guard_leq(size, dim_size)
sizevars.guard_lt(0, step) # type: ignore[arg-type]
sizevars.check_leq(size, dim_size)
sizevars.check_lt(0, step) # type: ignore[arg-type]
new_dim_size = FloorDiv(dim_size - size, step) + 1
if sizevars.size_hint_or_throw(dim_size) > 0:
@ -2912,8 +2912,8 @@ def select_scatter(x, src, dim: int, index: int):
dim = _validate_dim(x, dim, 0)
if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
index = index + x.get_size()[dim]
V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type]
V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type]
V.graph.sizevars.check_leq(0, index) # type: ignore[arg-type]
V.graph.sizevars.check_lt(index, x.get_size()[dim]) # type: ignore[arg-type]
src = expand(unsqueeze(src, dim), x.get_size())
src_loader = src.make_loader()
@ -4349,10 +4349,10 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None
if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
# Sliding windows must start within the input or left padding
x_alt -= 1 # type: ignore[assignment]
V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type]
V.graph.sizevars.check_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type]
if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
# ceil mode is actually a no-op, lets guard on that
V.graph.sizevars.guard_equals(x_out, x_alt)
V.graph.sizevars.check_equals(x_out, x_alt)
ceil_mode = False
else:
x_out = x_alt

View File

@ -306,7 +306,7 @@ class SizeVarAllocator:
# Note - [On Statically Known]
# The statically_known_* family of functions below NEVER guard, they could return True if the
# asked questions can be answered without guarding otherwise they return False.
# Those are similar to statically_known_true in symbolic_shapes but operate on sympy
# Those are similar to statically_known_true in symbolic_shapes.py but operate on sympy
# expressions instead of symnodes.
def statically_known_true(self, expr: Union[sympy.Basic, bool]) -> bool:
"""
@ -379,62 +379,56 @@ class SizeVarAllocator:
"""
return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr))
# The guard functions require you to ALREADY KNOW that a particular
# condition holds. If you don't know (you want to guard on an expression
# being a particular value, and then get access to that value), use
# the evaluate functions.
# The expect/check functions require you to ALREADY KNOW that a particular
# condition holds. They are similar to expect_true in symbolic_shapes.py and
# torch.check but operates on sympy expressions instead of symnodes.
def expect_true(self, expr: Expr) -> bool:
"""
Use it when you already know that expr is true or should be true and want to
ensure that guards/runtime assertions are in place to ensure this in compiled
function. Unlike check, this WON'T raise an error if expr isn't actually true.
check Note [expect_true].
"""
if not self.statically_known_true(expr):
return self.shape_env.guard_or_defer_runtime_assert(
expr, "sizevars.expect_true"
)
return True
def guard_equals(self, left: Expr, right: Expr) -> Expr:
if isinstance(left, Expr):
left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type]
if isinstance(right, Expr):
right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type]
def check(self, expr: Expr) -> None:
"""
Use it when you already know that expr is true or should be true and want to
ensure that guards/runtime assertions are in place to ensure this in compiled
function. Unlike expect_true, this WILL raise an error if expr isn't actually true.
check Note [expect_true].
"""
expr = sympy_subs(expr, self.inv_precomputed_replacements)
assert self.expect_true(expr)
expr = sympy.Eq(left, right)
static_expr = self.shape_env._maybe_evaluate_static(expr)
def check_equals(self, left: Expr, right: Expr) -> None:
"""
check(sympy.Eq(left, right)).
if static_expr is not None:
assert bool(static_expr)
return left
assert self.shape_env.guard_or_defer_runtime_assert(expr, "guard_equals")
"""
self.check(sympy.Eq(left, right))
return left
def guard_leq(self, left: Expr, right: Expr) -> None:
return self.guard_lt(left, right + 1)
def guard_lt(self, left: Expr, right: Expr) -> None:
expr = sympy.Lt(left, right)
static_expr = self.shape_env._maybe_evaluate_static(expr)
if static_expr is not None:
assert bool(static_expr)
return
assert self.shape_env.guard_or_defer_runtime_assert(expr, "guard_lt")
def guarded_order(self, seq):
def check_equals_and_simplify(self, left: Expr, right: Expr) -> Expr:
"""
Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.
check(sympy.Eq(left, right)) and returns left after applying
inv_precomputed_replacements.
"""
seq = [*map(self.remove_precomputed_replacements, seq)]
seq = [
(self.size_hint_or_throw(var), orig_idx, var)
for orig_idx, var in enumerate(seq)
]
seq.sort()
order = [-1] * len(seq)
last_var = None
for new_index, (_, orig_index, var) in enumerate(seq):
order[orig_index] = new_index
if last_var is not None:
self.guard_leq(last_var, var)
last_var = var
return order
self.check(sympy.Eq(left, right))
return sympy_subs(left, self.inv_precomputed_replacements)
# Similar to the functions guard_or_false/guard_or_true in symbolic_shapes but operates on sympy
# expressions instead of symnodes. see Note [guard_or_].
def check_leq(self, left: Expr, right: Expr) -> None:
self.check(sympy.Le(left, right))
def check_lt(self, left: Expr, right: Expr) -> None:
self.check(sympy.Lt(left, right))
# Similar to the functions guard_or_false/guard_or_true in symbolic_shapes.py
# but operates on sympy expressions instead of symnodes. see Note [guard_or_].
def guard_or_false(self, left):
return self.evaluate_expr(left, fallback_value=False)
@ -485,10 +479,10 @@ class SizeVarAllocator:
f"evaluate_min({left}, {right}) with unbacked symints"
) from None
if lv <= rv:
self.guard_leq(left, right)
self.check_leq(left, right)
return left
else:
self.guard_leq(right, left)
self.check_leq(right, left)
return right
def evaluate_max(self, left: Expr, right: Expr) -> Expr:
@ -502,7 +496,7 @@ class SizeVarAllocator:
if isinstance(left, int):
return left
right = self.size_hint_or_throw(left)
self.guard_equals(left, sympy.Integer(right))
self.check_equals(left, sympy.Integer(right))
return int(right)
def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]: