Factor var_to_range assignments to _update_var_to_range helper (#124283)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124283
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #124105, #124059, #124176
This commit is contained in:
Edward Z. Yang 2024-04-21 04:20:05 -07:00 committed by PyTorch MergeBot
parent cbf420b67a
commit 0e6367dd44

View File

@ -258,8 +258,9 @@ def check_consistent(new, old) -> None:
# simplifies right now)
for i, j in zip(old.shape, new.shape):
torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)")
elif isinstance(new, scalar_types):
assert isinstance(old, scalar_types)
# NB: bool is subclass of int
elif isinstance(new, scalar_types) and not isinstance(new, bool):
assert isinstance(old, scalar_types) and not isinstance(old, bool), f"{old} != {new}"
torch._check(old == new, lambda: f"{old} != {new} (old != new)")
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
@ -515,8 +516,8 @@ def guard_scalar(a):
def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int):
upd_vr = ValueRanges(compiler_min, compiler_max)
old_vr = shape_env.var_to_range.get(s, ValueRanges.unknown())
new_vr = shape_env.var_to_range[s] = old_vr & upd_vr
if new_vr != old_vr:
shape_env._update_var_to_range(s, upd_vr)
if (new_vr := shape_env.var_to_range[s]) != old_vr:
log.info("_constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper)
@ -3857,6 +3858,25 @@ class ShapeEnv:
# problem
)
def _update_var_to_range(self, symbol, vr):
lower, upper = vr.lower, vr.upper
# If we have a size-like unbacked SymInt, refuse to refine the range to be
# less than two. This is because when we intersect this range
# with [2, inf] for size oblivious tests, the range would be
# unsatisfiable. In other words, once you have a size-like
# unbacked SymInt, we can never learn that it is exactly zero or one,
# because we would now give inconsistent results for all size
# oblivous tests!
if upper < 2 and symbol in self.size_like:
upper = 2
# Updates the range and the guards corresponding to each bound of the symbol.
if symbol not in self.var_to_range:
self.var_to_range[symbol] = ValueRanges(lower, upper)
else:
self.var_to_range[symbol] &= ValueRanges(lower, upper)
def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None:
"""
Adds or updates a replacement for a symbol.
@ -3889,7 +3909,7 @@ class ShapeEnv:
# substitution in the end. This might be a no-op, if a already has
# a tighter bound
tgt_bound = self.bound_sympy(tgt)
self.var_to_range[a] = src_bound & tgt_bound
self._update_var_to_range(a, tgt_bound)
# Next, check if we can update the range of free symbols in tgt
# based on the range in a. But only do it if:
@ -4094,9 +4114,9 @@ class ShapeEnv:
# Propagate the value ranges. It doesn't really
# matter if we use truediv or floordiv, because we
# have established divisibility.
self.var_to_range[i1] = SymPyValueRangeAnalysis.truediv(
self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv(
self.var_to_range[i0], ValueRanges.wrap(d)
)
))
# Propagate size-like-ness
if i0 in self.size_like:
self.size_like.add(i1)
@ -4523,7 +4543,7 @@ class ShapeEnv:
continue
# Updates the range and the guards corresponding to each bound of the symbol.
self.var_to_range[symbol] = ValueRanges(lower, upper)
self._update_var_to_range(symbol, ValueRanges(lower, upper))
# Clears the cache, since this update can change the result.
self._maybe_evaluate_static.cache_clear()