mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Factor var_to_range assignments to _update_var_to_range helper (#124283)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/124283 Approved by: https://github.com/IvanKobzarev ghstack dependencies: #124105, #124059, #124176
This commit is contained in:
parent
cbf420b67a
commit
0e6367dd44
|
|
@ -258,8 +258,9 @@ def check_consistent(new, old) -> None:
|
|||
# simplifies right now)
|
||||
for i, j in zip(old.shape, new.shape):
|
||||
torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)")
|
||||
elif isinstance(new, scalar_types):
|
||||
assert isinstance(old, scalar_types)
|
||||
# NB: bool is subclass of int
|
||||
elif isinstance(new, scalar_types) and not isinstance(new, bool):
|
||||
assert isinstance(old, scalar_types) and not isinstance(old, bool), f"{old} != {new}"
|
||||
torch._check(old == new, lambda: f"{old} != {new} (old != new)")
|
||||
|
||||
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
|
||||
|
|
@ -515,8 +516,8 @@ def guard_scalar(a):
|
|||
def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int):
|
||||
upd_vr = ValueRanges(compiler_min, compiler_max)
|
||||
old_vr = shape_env.var_to_range.get(s, ValueRanges.unknown())
|
||||
new_vr = shape_env.var_to_range[s] = old_vr & upd_vr
|
||||
if new_vr != old_vr:
|
||||
shape_env._update_var_to_range(s, upd_vr)
|
||||
if (new_vr := shape_env.var_to_range[s]) != old_vr:
|
||||
log.info("_constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper)
|
||||
|
||||
|
||||
|
|
@ -3857,6 +3858,25 @@ class ShapeEnv:
|
|||
# problem
|
||||
)
|
||||
|
||||
def _update_var_to_range(self, symbol, vr):
|
||||
lower, upper = vr.lower, vr.upper
|
||||
|
||||
# If we have a size-like unbacked SymInt, refuse to refine the range to be
|
||||
# less than two. This is because when we intersect this range
|
||||
# with [2, inf] for size oblivious tests, the range would be
|
||||
# unsatisfiable. In other words, once you have a size-like
|
||||
# unbacked SymInt, we can never learn that it is exactly zero or one,
|
||||
# because we would now give inconsistent results for all size
|
||||
# oblivous tests!
|
||||
if upper < 2 and symbol in self.size_like:
|
||||
upper = 2
|
||||
|
||||
# Updates the range and the guards corresponding to each bound of the symbol.
|
||||
if symbol not in self.var_to_range:
|
||||
self.var_to_range[symbol] = ValueRanges(lower, upper)
|
||||
else:
|
||||
self.var_to_range[symbol] &= ValueRanges(lower, upper)
|
||||
|
||||
def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None:
|
||||
"""
|
||||
Adds or updates a replacement for a symbol.
|
||||
|
|
@ -3889,7 +3909,7 @@ class ShapeEnv:
|
|||
# substitution in the end. This might be a no-op, if a already has
|
||||
# a tighter bound
|
||||
tgt_bound = self.bound_sympy(tgt)
|
||||
self.var_to_range[a] = src_bound & tgt_bound
|
||||
self._update_var_to_range(a, tgt_bound)
|
||||
|
||||
# Next, check if we can update the range of free symbols in tgt
|
||||
# based on the range in a. But only do it if:
|
||||
|
|
@ -4094,9 +4114,9 @@ class ShapeEnv:
|
|||
# Propagate the value ranges. It doesn't really
|
||||
# matter if we use truediv or floordiv, because we
|
||||
# have established divisibility.
|
||||
self.var_to_range[i1] = SymPyValueRangeAnalysis.truediv(
|
||||
self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv(
|
||||
self.var_to_range[i0], ValueRanges.wrap(d)
|
||||
)
|
||||
))
|
||||
# Propagate size-like-ness
|
||||
if i0 in self.size_like:
|
||||
self.size_like.add(i1)
|
||||
|
|
@ -4523,7 +4543,7 @@ class ShapeEnv:
|
|||
continue
|
||||
|
||||
# Updates the range and the guards corresponding to each bound of the symbol.
|
||||
self.var_to_range[symbol] = ValueRanges(lower, upper)
|
||||
self._update_var_to_range(symbol, ValueRanges(lower, upper))
|
||||
# Clears the cache, since this update can change the result.
|
||||
self._maybe_evaluate_static.cache_clear()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user