Value range refinement using uni-variate expressions. (#97963)

This PR introduces value range refinement of shape symbols by symbolically evaluating the
value range of the involved guards. This should help `_maybe_evaluate_static` to eliminate
more guards.

This is a stack of PRs created from the discussion on: #96616.

In summary, this PR:
- simplifies `FloorDiv` nodes on the left-hand side of an expression so as to isolate a
symbol in the numerator
- tries to match the expression against the form: `<symbol> <relop> <expr>`
- uses the matched expression for refining the value range of `<symbol>` using the range
of `<expr>`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97963
Approved by: https://github.com/ezyang
This commit is contained in:
Yukio Siraichi 2023-06-29 11:39:48 -03:00 committed by PyTorch MergeBot
parent e311bed2a8
commit ffb526a2e4
2 changed files with 214 additions and 13 deletions

View File

@ -1347,6 +1347,20 @@ def forward(self, a_1):
"""["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]""" # noqa: B950
)
def test_guard_upperbound_range_refinement(self):
    """
    Two 'greater than' asserts on the same dimension should produce a
    single guard: the stricter one.

    NOTE: despite the test name, both asserts constrain the size from below.
    """
    def f(a):
        # The argument must stay named 'a': the expected guard string
        # below refers to it as L['a'].
        assert a.shape[0] > 5 and a.shape[0] > 12
        return a.cos()

    sample = torch.randn(15)
    traced = make_fx(f, tracing_mode="symbolic")(sample)
    guards = show_guards(traced)
    self.assertExpectedInline(guards, """L['a'].size()[0] > 12""")
def test_guard_lowerbound_range_refinement(self):
    """
    Two 'less than' asserts on the same dimension should produce a
    single guard: the stricter one.

    NOTE: despite the test name, both asserts constrain the size from above.
    """
    def f(a):
        # The argument must stay named 'a': the expected guard string
        # below refers to it as L['a'].
        assert a.shape[0] < 20 and a.shape[0] < 30
        return a.cos()

    sample = torch.randn(15)
    traced = make_fx(f, tracing_mode="symbolic")(sample)
    guards = show_guards(traced)
    self.assertExpectedInline(guards, """L['a'].size()[0] < 20""")
def test_sym_storage_offset(self):
def f(x, y):
return x + y

View File

@ -1482,6 +1482,7 @@ def _lru_cache(fn, maxsize=None):
fn_cache.cache_clear()
return fn_cache(self, *args, **kwargs)
wrapper.cache_clear = fn_cache.cache_clear
wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined]
return wrapper
@ -1988,6 +1989,10 @@ class ShapeEnv:
self.var_to_range: Dict["sympy.Symbol", ValueRanges] = {}
self.var_to_sources: Dict["sympy.Symbol", List[Source]] = {}
self.var_to_stack: Dict["sympy.Symbol", traceback.StackSummary] = {}
# Maps symbolic ints to the guards that refine their lower/upper
# bound. If one of them is None, it means that there are no guards
# that refine that respective bound.
self.var_to_guards: Dict["sympy.Symbol", Tuple[Optional[ShapeGuard], Optional[ShapeGuard]]] = {}
# Maps from sympy ints to expressions representing them
# Populated from equality guards (i.e. a.shape[0] == b.shape[0])
self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {} #
@ -2711,20 +2716,27 @@ Target Guards:
# 2. Every guard must evaluate to True (but remember many guards
# like s0 == s1*2 become trivial due to simplification)
for g, tb in self.guards:
if self._maybe_evaluate_static(g) is not None:
continue
g = self.simplify(g)
issued = set()
def issue_guard(guard: ShapeGuard) -> None:
expr = self.simplify(guard.expr)
# Avoid re-issuing the same guard.
if guard.expr in issued:
return
issued.add(expr)
try:
if any(is_dim(source) for s in g.free_symbols for source in symbol_to_source[s]):
self.dim_constraints.add(g)
guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(g)
if any(is_dim(source) for s in expr.free_symbols for source in symbol_to_source[s]):
self.dim_constraints.add(expr)
guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
exprs.append(guard_expr)
self._add_target_expr(g)
self._add_target_expr(expr)
# A non-relational constraint on a single sizevar can violate
# a constraint
if len(g.free_symbols) == 1:
symbol = list(g.free_symbols)[0]
if len(expr.free_symbols) == 1:
symbol = list(expr.free_symbols)[0]
source = symbol_to_source[symbol][0]
constraints = symbol_to_constraints[symbol]
for c in constraints:
@ -2743,9 +2755,26 @@ Target Guards:
else:
raise AssertionError(f"unrecognized constraint {c}")
except Exception:
self.log.warning("Failing guard allocated at: \n%s", tb)
self.log.warning("Failing guard allocated at: \n%s", guard.stack)
raise
# First, issue all the non-trivial guards.
for guard in self.guards:
if self._maybe_evaluate_static(guard.expr) is not None:
continue
issue_guard(guard)
# Then, issue the guards that refine the value range of tracked symbols.
# We need to explicitly issue these guards, since they are the ones that
# guarantee the symbol's value range. Plus, due to the updated value
# range, they may be skipped in the previous step.
for symbol, guards in self.var_to_guards.items():
if symbol not in symbol_to_source:
continue
for guard in guards:
if guard is not None:
issue_guard(guard)
# 3. Every symbol must be within its value range (this handles 0/1
# specialization too). NB: because we never update value ranges
# except in case of explicit user annotation, these are not included
@ -2774,8 +2803,9 @@ Target Guards:
assert sources
assert symbol.is_integer
g_lower, g_upper = self.var_to_guards.get(symbol, (None, None))
bounds = []
if r.lower != -sympy.oo:
if r.lower != -sympy.oo and g_lower is None:
if any(is_dim(source) for source in sources):
self.dim_constraints.add(sympy.Ge(symbol, r.lower))
bounds.append(str(r.lower))
@ -2787,7 +2817,7 @@ Target Guards:
# Note that you can be off by a pretty large constant and it
# won't matter because sizes in practice will be no where near
# the 64-bit limit.
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
if r.upper != sympy.oo and r.upper < sys.maxsize - 1 and g_upper is None:
if any(is_dim(source) for source in sources):
self.dim_constraints.add(sympy.Le(symbol, r.upper))
bounds.append(str(r.upper))
@ -2986,6 +3016,62 @@ Target Guards:
self.divisible = new_divisible
@_lru_cache
def try_isolate_symbol_lhs(self, expr: "sympy.Expr") -> "sympy.Expr":
    """
    Try to rewrite 'expr' so that the left-hand side of a relation is simpler.

    Two rewrites are attempted, in order:
      1. An integer constant added at the top level of the left-hand side
         is moved to the right-hand side.
      2. A relation whose left-hand side is 'FloorDiv(a, b)', where 'b' is
         a positive integer constant, is rewritten into relation(s) on 'a'.

    Every FloorDiv rewrite (including '==' and '!=') is only sound for a
    positive denominator: for a negative 'b' the inequalities flip, and for
    a symbolic 'b' the sign is unknown. We only know the sign of constants,
    so anything else is left untouched.

    The FloorDiv rewrites also assume an integer-valued right-hand side;
    this is not checked here.

    Returns the rewritten expression, or the expression as it stands when
    no (further) rewrite applies.
    """
    def get_added_const(e):
        """
        Returns an integer constant being added at the top-level of this expression.
        """
        if isinstance(e, sympy.Add):
            for a in e.args:
                if isinstance(a, sympy.Integer):
                    return a
        return None

    def is_floordiv_with_positive_denominator(e) -> bool:
        if not isinstance(e, FloorDiv):
            return False
        number = e.args[1]
        return isinstance(number, sympy.Integer) and bool(number > 0)

    # Move any constants in the left-hand side to the right-hand side.
    if isinstance(expr, sympy.Rel):
        lhs_const = get_added_const(expr.lhs)
        if lhs_const is not None:
            expr = type(expr)(expr.lhs - lhs_const, expr.rhs - lhs_const)  # type: ignore[arg-type]

    # Nothing else to do unless this is (still) a relation of the form
    # 'a // b <relop> rhs' with a positive constant 'b'. The relation built
    # above may have been evaluated by sympy into a plain boolean.
    if not (isinstance(expr, sympy.Rel) and is_floordiv_with_positive_denominator(expr.lhs)):
        return expr

    numerator, denominator = expr.lhs.args
    rhs = expr.rhs

    # a // b == rhs
    # => a >= b * rhs and a < b * (rhs + 1)
    if isinstance(expr, sympy.Eq):
        return sympy.And(
            sympy.Ge(numerator, (rhs * denominator)),  # type: ignore[arg-type]
            sympy.Lt(numerator, ((rhs + 1) * denominator))  # type: ignore[arg-type]
        )

    # a // b != rhs
    # => a < b * rhs or a >= b * (rhs + 1)
    if isinstance(expr, sympy.Ne):
        return sympy.Or(
            sympy.Lt(numerator, (rhs * denominator)),  # type: ignore[arg-type]
            sympy.Ge(numerator, ((rhs + 1) * denominator))  # type: ignore[arg-type]
        )

    # a // b > rhs  => a >= b * (rhs + 1)
    # a // b >= rhs => a >= b * rhs
    if isinstance(expr, (sympy.Gt, sympy.Ge)):
        quotient = rhs if isinstance(expr, sympy.Ge) else (rhs + 1)  # type: ignore[arg-type]
        return sympy.Ge(numerator, (quotient * denominator))  # type: ignore[arg-type]

    # a // b < rhs  => a < b * rhs
    # a // b <= rhs => a < b * (rhs + 1)
    if isinstance(expr, (sympy.Lt, sympy.Le)):
        quotient = rhs if isinstance(expr, sympy.Lt) else (rhs + 1)  # type: ignore[arg-type]
        return sympy.Lt(numerator, (quotient * denominator))  # type: ignore[arg-type]

    return expr
@_lru_cache
def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
expr = self.replace(expr)
@ -3256,6 +3342,7 @@ Target Guards:
stack = ''.join(traceback.format_list(tb))
guard = ShapeGuard(g, stack)
self.guards.append(guard)
self.refine_ranges(guard)
if self.log.isEnabledFor(logging.INFO):
for frame in reversed(tb):
if frame.filename not in uninteresting_files():
@ -3289,6 +3376,106 @@ Target Guards:
return concrete_val
# Refines the value ranges of the variables present in 'guard'.
#
# Only guards that simplify to a 'sympy.Relational' with a tracked symbol
# alone on the left-hand side, and no free symbols on the right-hand side,
# can refine anything. For each candidate form of the guard, this function:
#   1. Tries to isolate a variable in the left-hand side
#   2. Computes the value range of the right-hand side
#   3. Updates the value range of the variable, if better
def refine_ranges(self, guard: ShapeGuard) -> None:
    # Relations that can tighten the lower / upper bound, respectively.
    # sympy.Eq may update both bounds.
    lower_relops = (sympy.Eq, sympy.Ge, sympy.Gt)
    upper_relops = (sympy.Eq, sympy.Le, sympy.Lt)

    # a < b ==> b > a (and so on): lets us consider the RHS as well.
    mirror_of = {
        sympy.Ge: sympy.Le,
        sympy.Gt: sympy.Lt,
        sympy.Le: sympy.Ge,
        sympy.Lt: sympy.Gt,
    }

    def isolate_once(e: sympy.Expr) -> sympy.Expr:
        """
        One round of the simplification specialized for range refinement.
        """
        return self.try_isolate_symbol_lhs(self.simplify(e))

    def isolate_to_fixpoint(e: sympy.Expr, max_iterations: int = 10) -> sympy.Expr:
        """
        Repeats 'isolate_once' until the expression stops changing, or until
        the maximum number of iterations is reached.
        """
        for _ in range(max_iterations):
            simplified = isolate_once(e)
            if simplified == e:
                return simplified
            e = simplified
        return e

    # Candidate expressions: the guard itself plus, for inequalities,
    # its mirrored form.
    candidates = [guard.expr]
    mirrored_relop = mirror_of.get(type(guard.expr))
    if mirrored_relop is not None:
        candidates.append(mirrored_relop(guard.expr.rhs, guard.expr.lhs))  # type: ignore[arg-type]

    for candidate in candidates:
        # First, try to simplify the left-hand side.
        relation = isolate_to_fixpoint(candidate)

        # Keep only expressions that:
        #   1. are relational operations
        #   2. have a symbol as the left-hand side
        #   3. whose symbol already has a range
        if not isinstance(relation, sympy.Rel):
            continue
        symbol = relation.lhs
        if not isinstance(symbol, sympy.Symbol) or symbol not in self.var_to_range:
            continue

        # Use only univariate functions.
        if relation.rhs.free_symbols:
            continue

        current = self.var_to_range[symbol]
        new_lower, new_upper = current.lower, current.upper
        rhs_range = sympy_interp(ValueRangeAnalysis, self.var_to_range, relation.rhs)  # type: ignore[arg-type]
        lower_guard, upper_guard = self.var_to_guards.get(symbol, (None, None))

        # Let's suppose that we have a preexisting range for x [0, 100].
        # Now, we issue a guard x > y, where the range for y is [50, 150].
        # Then, lower = 0, rhs lower = 50 and therefore refinement can happen,
        # refining x to [51, 100], since x must be greater than y, but the
        # lowest y could be is 50.
        if new_lower < rhs_range.lower and isinstance(relation, lower_relops):
            # Strictly greater relations allow us to refine a bit more, since
            # x > y implies that the lower bound for x is: y + 1.
            new_lower = rhs_range.lower + int(isinstance(relation, sympy.Gt))
            lower_guard = guard
        if new_upper > rhs_range.upper and isinstance(relation, upper_relops):
            # Likewise, x < y implies that the upper bound for x is: y - 1.
            new_upper = rhs_range.upper - int(isinstance(relation, sympy.Lt))
            upper_guard = guard

        # Do nothing if the new value range is no better than what we already have.
        refined = ValueRanges(new_lower, new_upper)
        if current == refined:
            continue

        # Record the new range together with the guards responsible for
        # each of its bounds.
        self.var_to_range[symbol] = refined
        self.var_to_guards[symbol] = (lower_guard, upper_guard)
        # Clears the cache, since this update can change the result.
        self._maybe_evaluate_static.cache_clear()
def _is_int(expr):
if not isinstance(expr, SymInt):
return False