mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Unify dynamic shapes APIs naming 2 (expect_true and check) attempt2 (#156518)
Summary: The functions guard_lt, guard_equals, and guard_leq work similarly to torch.check and expect_true, but they operate on SymPy expressions. Notably, guard_equals applies local replacements before comparison, which might be better extracted into a separate function. This pull request standardizes naming conventions to match symbolic_shapes.py. Specifically, - it introduces size_vars.expect_true and size_vars.check. - guard_lt becomes check_lt - guard_leq becomes check_leq - guard_equals becomes check_equals I am also seeing a couple of wrong usages !! that i will fix in the next PR Test Plan: OSS and cont Rollback Plan: Differential Revision: D77054177 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156518 Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent
dfef1e4408
commit
26f7ca3972
|
|
@ -3735,7 +3735,7 @@ class TilingSelect:
|
|||
call_ranges[tiling_indice], fallback=0
|
||||
)
|
||||
if call_range < factor_lowp:
|
||||
V.graph.sizevars.guard_lt(call_range, factor_lowp) # type: ignore[arg-type]
|
||||
V.graph.sizevars.check_lt(call_range, factor_lowp) # type: ignore[arg-type]
|
||||
tiling_factor = factor_lowp // 2
|
||||
break
|
||||
elif call_ranges[tiling_indice] < factor_lowp:
|
||||
|
|
|
|||
|
|
@ -648,7 +648,7 @@ def eq(left, right):
|
|||
except TypeError: # unbacked symints
|
||||
return False
|
||||
if a == b:
|
||||
V.graph.sizevars.guard_equals(left, right)
|
||||
V.graph.sizevars.check_equals(left, right)
|
||||
return a == b
|
||||
|
||||
|
||||
|
|
@ -664,7 +664,7 @@ def lt(left, right):
|
|||
return left != right
|
||||
return False
|
||||
if a < b:
|
||||
V.graph.sizevars.guard_lt(left, right)
|
||||
V.graph.sizevars.check_lt(left, right)
|
||||
return a < b
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1397,9 +1397,9 @@ class SIMDScheduling(BaseScheduling):
|
|||
|
||||
# Only install guards for 32-bit indexing as there is no correctness
|
||||
# issue with using 64-bit for everything
|
||||
V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type]
|
||||
V.graph.sizevars.check_leq(numel, int_max) # type: ignore[arg-type]
|
||||
for size in buf_sizes:
|
||||
V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type]
|
||||
V.graph.sizevars.check_leq(size, int_max) # type: ignore[arg-type]
|
||||
return True
|
||||
|
||||
def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
|
|||
# no hint: we'll see if we know that this is a 32-bit int, and guard if possible.
|
||||
int_max = torch.iinfo(torch.int32).max
|
||||
if expr_fits_within_32bit(arg.expr):
|
||||
V.graph.sizevars.guard_leq(arg.expr, int_max)
|
||||
V.graph.sizevars.check_leq(arg.expr, int_max)
|
||||
return "i32"
|
||||
else:
|
||||
return "i64"
|
||||
|
|
|
|||
|
|
@ -3034,7 +3034,7 @@ class View(GenericView):
|
|||
new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
|
||||
break
|
||||
|
||||
V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size))
|
||||
V.graph.sizevars.check_equals(sympy_product(old_size), sympy_product(new_size))
|
||||
return old_size, new_size
|
||||
|
||||
@classmethod
|
||||
|
|
@ -3091,14 +3091,14 @@ class View(GenericView):
|
|||
stack_old.append(size_old) # re-add
|
||||
elif size_hint(size_new) == size_hint(size_old):
|
||||
view_expr.append(var)
|
||||
V.graph.sizevars.guard_equals(size_new, size_old)
|
||||
V.graph.sizevars.check_equals(size_new, size_old)
|
||||
elif size_hint(size_new) < size_hint(size_old):
|
||||
while size_hint(size_new) < size_hint(size_old):
|
||||
var2, size_new2 = stack_new.pop()
|
||||
var = var2 * size_new + var
|
||||
size_new = size_new * size_new2
|
||||
view_expr.append(var)
|
||||
V.graph.sizevars.guard_equals(size_new, size_old)
|
||||
V.graph.sizevars.check_equals(size_new, size_old)
|
||||
elif size_hint(size_new) > size_hint(size_old):
|
||||
divisor = sympy.S.One
|
||||
modulus = size_old
|
||||
|
|
@ -3109,18 +3109,18 @@ class View(GenericView):
|
|||
view_expr.append(ModularIndexing(var, divisor, modulus))
|
||||
divisor = divisor * modulus
|
||||
size_old = size_old * modulus
|
||||
V.graph.sizevars.guard_equals(size_new, size_old)
|
||||
V.graph.sizevars.check_equals(size_new, size_old)
|
||||
else:
|
||||
raise AssertionError
|
||||
|
||||
while stack_old:
|
||||
size_old = stack_old.pop()
|
||||
V.graph.sizevars.guard_equals(size_old, 1)
|
||||
V.graph.sizevars.check_equals(size_old, 1)
|
||||
view_expr.append(sympy.S.Zero)
|
||||
|
||||
while stack_new:
|
||||
var, size_new = stack_new.pop()
|
||||
V.graph.sizevars.guard_equals(size_new, 1)
|
||||
V.graph.sizevars.check_equals(size_new, 1)
|
||||
|
||||
if dense_dim is not None and len(new_size) == 1:
|
||||
view_expr.reverse()
|
||||
|
|
@ -3958,7 +3958,7 @@ class MutationLayoutSHOULDREMOVE(Layout):
|
|||
dtype=src.get_dtype(),
|
||||
inner_fn=src.make_loader(),
|
||||
ranges=[
|
||||
V.graph.sizevars.guard_equals(a, b)
|
||||
V.graph.sizevars.check_equals_and_simplify(a, b)
|
||||
for a, b in zip(src.get_size(), dst.get_size())
|
||||
],
|
||||
).data
|
||||
|
|
@ -4948,7 +4948,7 @@ class ConcatKernel(NopKernel):
|
|||
if j == dim:
|
||||
new_size[j] = new_size[j] + input_size[j]
|
||||
else:
|
||||
new_size[j] = V.graph.sizevars.guard_equals(
|
||||
new_size[j] = V.graph.sizevars.check_equals_and_simplify(
|
||||
new_size[j], input_size[j]
|
||||
)
|
||||
offsets_end.append(new_size[dim])
|
||||
|
|
@ -5085,7 +5085,7 @@ class ConcatKernel(NopKernel):
|
|||
dtype=src.get_dtype(),
|
||||
inner_fn=src.make_loader(),
|
||||
ranges=[
|
||||
V.graph.sizevars.guard_equals(a, b)
|
||||
V.graph.sizevars.check_equals_and_simplify(a, b)
|
||||
for a, b in zip(src.get_size(), dst.get_size())
|
||||
],
|
||||
)
|
||||
|
|
@ -8051,7 +8051,7 @@ class WhileLoop(ExternKernel):
|
|||
rhs_exprs: Sequence[Union[int, Any]],
|
||||
) -> None:
|
||||
for lhs, rhs in zip(lhs_exprs, rhs_exprs):
|
||||
V.graph.sizevars.guard_equals(lhs, rhs)
|
||||
V.graph.sizevars.check_equals(lhs, rhs)
|
||||
|
||||
_guard_list_equals(op.get_size(), bo.get_size())
|
||||
_guard_list_equals(op.get_stride(), bo.get_stride())
|
||||
|
|
|
|||
|
|
@ -483,7 +483,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
|
|||
)
|
||||
query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride)
|
||||
|
||||
V.graph.sizevars.guard_leq(
|
||||
V.graph.sizevars.check_leq(
|
||||
seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"])
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1058,8 +1058,8 @@ def tuned_sparse_semi_structured_mm(
|
|||
m1, k1 = mat1.get_size()
|
||||
m2, _ = mat1_meta.get_size()
|
||||
k2, n = mat2.get_size()
|
||||
m = V.graph.sizevars.guard_equals(m1, m2)
|
||||
k = V.graph.sizevars.guard_equals(2 * k1, k2)
|
||||
m = V.graph.sizevars.check_equals_and_simplify(m1, m2)
|
||||
k = V.graph.sizevars.check_equals_and_simplify(2 * k1, k2)
|
||||
|
||||
if layout is None:
|
||||
from torch._inductor.ir import FixedLayout
|
||||
|
|
|
|||
|
|
@ -157,10 +157,10 @@ def mm_args(
|
|||
*b2, n, k2 = mat2.get_size()
|
||||
else:
|
||||
*b2, k2, n = mat2.get_size()
|
||||
b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
|
||||
b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)]
|
||||
if use_4x2_dim:
|
||||
k2 = k2 * 2
|
||||
k = V.graph.sizevars.guard_equals(k1, k2)
|
||||
k = V.graph.sizevars.check_equals_and_simplify(k1, k2)
|
||||
if layout is None:
|
||||
from torch._inductor.ir import FixedLayout
|
||||
|
||||
|
|
|
|||
|
|
@ -611,28 +611,28 @@ def _tuned_grouped_mm_common(
|
|||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.guard_equals(k1, k2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.guard_equals(g1, g2)
|
||||
V.graph.sizevars.guard_equals(k1, k2)
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.guard_equals(g1, g2)
|
||||
V.graph.sizevars.guard_equals(k1, k2)
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.guard_equals(g1, g2)
|
||||
V.graph.sizevars.guard_equals(k1, k2)
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor")
|
||||
|
|
|
|||
|
|
@ -494,7 +494,7 @@ def broadcast_symbolic_shapes(a, b):
|
|||
):
|
||||
output.append(y)
|
||||
else:
|
||||
V.graph.sizevars.guard_equals(x, y)
|
||||
V.graph.sizevars.check_equals(x, y)
|
||||
if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols):
|
||||
output.append(y) # prefer shorter formula
|
||||
else:
|
||||
|
|
@ -1813,8 +1813,8 @@ def unfold(x, dimension, size, step):
|
|||
|
||||
dim_size = sizes[dim]
|
||||
sizevars = V.graph.sizevars
|
||||
sizevars.guard_leq(size, dim_size)
|
||||
sizevars.guard_lt(0, step) # type: ignore[arg-type]
|
||||
sizevars.check_leq(size, dim_size)
|
||||
sizevars.check_lt(0, step) # type: ignore[arg-type]
|
||||
|
||||
new_dim_size = FloorDiv(dim_size - size, step) + 1
|
||||
if sizevars.size_hint_or_throw(dim_size) > 0:
|
||||
|
|
@ -2912,8 +2912,8 @@ def select_scatter(x, src, dim: int, index: int):
|
|||
dim = _validate_dim(x, dim, 0)
|
||||
if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
|
||||
index = index + x.get_size()[dim]
|
||||
V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type]
|
||||
V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type]
|
||||
V.graph.sizevars.check_leq(0, index) # type: ignore[arg-type]
|
||||
V.graph.sizevars.check_lt(index, x.get_size()[dim]) # type: ignore[arg-type]
|
||||
src = expand(unsqueeze(src, dim), x.get_size())
|
||||
src_loader = src.make_loader()
|
||||
|
||||
|
|
@ -4349,10 +4349,10 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None
|
|||
if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
|
||||
# Sliding windows must start within the input or left padding
|
||||
x_alt -= 1 # type: ignore[assignment]
|
||||
V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type]
|
||||
V.graph.sizevars.check_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type]
|
||||
if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
|
||||
# ceil mode is actually a no-op, lets guard on that
|
||||
V.graph.sizevars.guard_equals(x_out, x_alt)
|
||||
V.graph.sizevars.check_equals(x_out, x_alt)
|
||||
ceil_mode = False
|
||||
else:
|
||||
x_out = x_alt
|
||||
|
|
|
|||
|
|
@ -306,7 +306,7 @@ class SizeVarAllocator:
|
|||
# Note - [On Statically Known]
|
||||
# The statically_known_* family of functions below NEVER guard, they could return True if the
|
||||
# asked questions can be answered without guarding otherwise they return False.
|
||||
# Those are similar to statically_known_true in symbolic_shapes but operate on sympy
|
||||
# Those are similar to statically_known_true in symbolic_shapes.py but operate on sympy
|
||||
# expressions instead of symnodes.
|
||||
def statically_known_true(self, expr: Union[sympy.Basic, bool]) -> bool:
|
||||
"""
|
||||
|
|
@ -379,62 +379,56 @@ class SizeVarAllocator:
|
|||
"""
|
||||
return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr))
|
||||
|
||||
# The guard functions require you to ALREADY KNOW that a particular
|
||||
# condition holds. If you don't know (you want to guard on an expression
|
||||
# being a particular value, and then get access to that value), use
|
||||
# the evaluate functions.
|
||||
# The expect/check functions require you to ALREADY KNOW that a particular
|
||||
# condition holds. They are similar to expect_true in symbolic_shapes.py and
|
||||
# torch.check but operates on sympy expressions instead of symnodes.
|
||||
def expect_true(self, expr: Expr) -> bool:
|
||||
"""
|
||||
Use it when you already know that expr is true or should be true and want to
|
||||
ensure that guards/runtime assertions are in place to ensure this in compiled
|
||||
function. Unlike check, this WON'T raise an error if expr isn't actually true.
|
||||
check Note [expect_true].
|
||||
"""
|
||||
if not self.statically_known_true(expr):
|
||||
return self.shape_env.guard_or_defer_runtime_assert(
|
||||
expr, "sizevars.expect_true"
|
||||
)
|
||||
return True
|
||||
|
||||
def guard_equals(self, left: Expr, right: Expr) -> Expr:
|
||||
if isinstance(left, Expr):
|
||||
left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type]
|
||||
if isinstance(right, Expr):
|
||||
right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type]
|
||||
def check(self, expr: Expr) -> None:
|
||||
"""
|
||||
Use it when you already know that expr is true or should be true and want to
|
||||
ensure that guards/runtime assertions are in place to ensure this in compiled
|
||||
function. Unlike expect_true, this WILL raise an error if expr isn't actually true.
|
||||
check Note [expect_true].
|
||||
"""
|
||||
expr = sympy_subs(expr, self.inv_precomputed_replacements)
|
||||
assert self.expect_true(expr)
|
||||
|
||||
expr = sympy.Eq(left, right)
|
||||
static_expr = self.shape_env._maybe_evaluate_static(expr)
|
||||
def check_equals(self, left: Expr, right: Expr) -> None:
|
||||
"""
|
||||
check(sympy.Eq(left, right)).
|
||||
|
||||
if static_expr is not None:
|
||||
assert bool(static_expr)
|
||||
return left
|
||||
|
||||
assert self.shape_env.guard_or_defer_runtime_assert(expr, "guard_equals")
|
||||
"""
|
||||
self.check(sympy.Eq(left, right))
|
||||
return left
|
||||
|
||||
def guard_leq(self, left: Expr, right: Expr) -> None:
|
||||
return self.guard_lt(left, right + 1)
|
||||
|
||||
def guard_lt(self, left: Expr, right: Expr) -> None:
|
||||
expr = sympy.Lt(left, right)
|
||||
static_expr = self.shape_env._maybe_evaluate_static(expr)
|
||||
|
||||
if static_expr is not None:
|
||||
assert bool(static_expr)
|
||||
return
|
||||
|
||||
assert self.shape_env.guard_or_defer_runtime_assert(expr, "guard_lt")
|
||||
|
||||
def guarded_order(self, seq):
|
||||
def check_equals_and_simplify(self, left: Expr, right: Expr) -> Expr:
|
||||
"""
|
||||
Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.
|
||||
check(sympy.Eq(left, right)) and returns left after applying
|
||||
inv_precomputed_replacements.
|
||||
"""
|
||||
seq = [*map(self.remove_precomputed_replacements, seq)]
|
||||
seq = [
|
||||
(self.size_hint_or_throw(var), orig_idx, var)
|
||||
for orig_idx, var in enumerate(seq)
|
||||
]
|
||||
seq.sort()
|
||||
order = [-1] * len(seq)
|
||||
last_var = None
|
||||
for new_index, (_, orig_index, var) in enumerate(seq):
|
||||
order[orig_index] = new_index
|
||||
if last_var is not None:
|
||||
self.guard_leq(last_var, var)
|
||||
last_var = var
|
||||
return order
|
||||
self.check(sympy.Eq(left, right))
|
||||
return sympy_subs(left, self.inv_precomputed_replacements)
|
||||
|
||||
# Similar to the functions guard_or_false/guard_or_true in symbolic_shapes but operates on sympy
|
||||
# expressions instead of symnodes. see Note [guard_or_].
|
||||
def check_leq(self, left: Expr, right: Expr) -> None:
|
||||
self.check(sympy.Le(left, right))
|
||||
|
||||
def check_lt(self, left: Expr, right: Expr) -> None:
|
||||
self.check(sympy.Lt(left, right))
|
||||
|
||||
# Similar to the functions guard_or_false/guard_or_true in symbolic_shapes.py
|
||||
# but operates on sympy expressions instead of symnodes. see Note [guard_or_].
|
||||
def guard_or_false(self, left):
|
||||
return self.evaluate_expr(left, fallback_value=False)
|
||||
|
||||
|
|
@ -485,10 +479,10 @@ class SizeVarAllocator:
|
|||
f"evaluate_min({left}, {right}) with unbacked symints"
|
||||
) from None
|
||||
if lv <= rv:
|
||||
self.guard_leq(left, right)
|
||||
self.check_leq(left, right)
|
||||
return left
|
||||
else:
|
||||
self.guard_leq(right, left)
|
||||
self.check_leq(right, left)
|
||||
return right
|
||||
|
||||
def evaluate_max(self, left: Expr, right: Expr) -> Expr:
|
||||
|
|
@ -502,7 +496,7 @@ class SizeVarAllocator:
|
|||
if isinstance(left, int):
|
||||
return left
|
||||
right = self.size_hint_or_throw(left)
|
||||
self.guard_equals(left, sympy.Integer(right))
|
||||
self.check_equals(left, sympy.Integer(right))
|
||||
return int(right)
|
||||
|
||||
def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user