pytorch/torch/utils/_sympy/solve.py
Aaron Gokaslan 1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replaces certain list comprehensions with generator ones where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was automatically applied.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00

176 lines
6.2 KiB
Python

import logging
from typing import Dict, Optional, Tuple, Type
import sympy
from torch.utils._sympy.functions import FloorDiv
log = logging.getLogger(__name__)
# Maps each supported relational type to the type expressing the same relation
# with its operands swapped: ``a OP b`` <=> ``b _MIRROR_REL_OP[OP] a``.
# Eq and Ne are symmetric; each ordering operator maps to its opposite.
# Relational types that are not keys here are rejected as unsupported by
# 'try_solve' (via 'mirror_rel_op' returning None).
_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = {
sympy.Eq: sympy.Eq,
sympy.Ne: sympy.Ne,
sympy.Ge: sympy.Le,
sympy.Gt: sympy.Lt,
sympy.Le: sympy.Ge,
sympy.Lt: sympy.Gt,
}
# The ordering relations. Unlike Eq/Ne, their direction must be flipped when
# both sides are divided by a negative factor (see '_try_isolate_lhs').
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]:
    """Return the relational type obtained by swapping the operands.

    ``a OP b`` is equivalent to ``b mirror_rel_op(OP) a``; for instance
    ``sympy.Gt`` maps to ``sympy.Lt`` and ``sympy.Eq`` maps to itself.

    Returns ``None`` when ``type`` has no entry in the mirror table.
    """
    try:
        return _MIRROR_REL_OP[type]
    except KeyError:
        return None
def try_solve(
    expr: sympy.Basic,
    thing: sympy.Basic,
    trials: int = 5,
    floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
    """Try to rewrite ``expr`` so that only ``thing`` is left on its left-hand side.

    Both orientations of ``expr`` are considered (``a < b`` is also tried as
    ``b > a``), but only those that have ``thing`` on their left-hand side.

    Args:
        expr: the relational expression to solve. Anything that is not a
            ``sympy.Rel`` with a known mirror operator is rejected.
        thing: the sub-expression to isolate on the left-hand side.
        trials: maximum number of rewriting steps (``_try_isolate_lhs`` calls)
            attempted per orientation of ``expr``.
        floordiv_inequality: whether a ``FloorDiv`` on the left-hand side may
            be converted into inequalities.

    Returns:
        A tuple ``(simplified, rhs)``, where ``simplified`` is a relational
        whose left-hand side is exactly ``thing`` and ``rhs`` is its
        right-hand side. ``None`` if ``thing`` could not be isolated.
    """
    mirror = mirror_rel_op(type(expr))

    # Ignore unsupported expressions:
    #   - Those that are not relational operations
    #   - Those that don't have a mirror (just avoiding unexpected classes)
    if not isinstance(expr, sympy.Rel) or mirror is None:
        log.debug("expression with unsupported type: %s", type(expr))
        return None

    lhs_has_thing = expr.lhs.has(thing)
    rhs_has_thing = expr.rhs.has(thing)

    # Give up when 'thing' appears on both sides of the relational expression:
    # the rewriting below can only isolate it when it lives on exactly one side.
    if lhs_has_thing and rhs_has_thing:
        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
        return None

    # Collect every orientation of 'expr' that has 'thing' on its left-hand
    # side, mirroring the operator when swapping sides: a < b ==> b > a
    expressions = []
    if lhs_has_thing:
        expressions.append(expr)
    if rhs_has_thing:
        expressions.append(mirror(expr.rhs, expr.lhs))

    for e in expressions:
        # sympy may evaluate the mirrored relation to a boolean constant
        # (e.g. when 'expr' was built unevaluated); nothing to isolate then.
        if not isinstance(e, sympy.Rel):
            continue

        for _ in range(trials):
            trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality)
            # Stop if there was no change in this trial.
            if trial == e:
                break
            e = trial  # type: ignore[assignment]

        # Return if we were able to isolate 'thing' on the left-hand side.
        # (The result may also be an And/Or from the FloorDiv rewrite, which
        # does not qualify.)
        if isinstance(e, sympy.Rel) and e.lhs == thing:
            return e, e.rhs

    return None
def _try_isolate_lhs(
    expr: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
) -> sympy.Basic:
    """Apply one round of rewrites that move ``expr`` towards ``thing OP rhs``.

    The rewrites are attempted in order:

    1. Terms of an ``Add`` left-hand side that do not contain ``thing`` are
       subtracted from both sides.
    2. Factors of a ``Mul`` left-hand side that do not contain ``thing`` are
       divided out of both sides. For an inequality divided by a negative
       factor, the operator is mirrored. The division is skipped whenever it
       would not be an equivalence.
    3. If ``floordiv_inequality`` is set and the left-hand side is ``a // b``,
       with ``b`` known positive and an integer right-hand side, the relation
       is rewritten into inequalities over ``a``.

    Args:
        expr: expression to rewrite. Anything that is not a ``sympy.Rel`` is
            returned unchanged.
        thing: the sub-expression being isolated on the left-hand side.
        floordiv_inequality: enables rewrite 3.

    Returns:
        The rewritten expression. This is a relational, or a
        ``sympy.And`` / ``sympy.Or`` produced by rewrite 3 for ``==`` / ``!=``.
        The input relation (possibly re-built) is returned when no rewrite
        applies.
    """
    e = expr
    op = type(expr)

    if isinstance(e, sympy.Rel):
        # Move any constants in the left-hand side to the right-hand side.
        lhs_not_thing = (
            sum(a for a in e.lhs.args if not a.has(thing))
            if isinstance(e.lhs, sympy.Add)
            else 0
        )
        e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing)  # type: ignore[attr-defined]

    # Divide both sides by the factors that don't contain thing.
    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
        lhs, rhs = e.args
        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])
        is_inequality = isinstance(e, INEQUALITY_TYPES)

        if is_inequality:
            # Dividing an inequality is only an equivalence when the sign of
            # 'other' is known: positive keeps the direction, negative flips
            # it. If the sign is unknown, or 'other' may be zero (which has
            # neither sign), we do nothing.
            can_divide = bool(other.is_positive or other.is_negative)
        else:
            # For '==' / '!=': 'other * t OP rhs' <=> 't OP rhs / other' holds
            # when 'other' is known to be non-zero, or when 'rhs' is known to
            # be non-zero (which forces 'other' to be non-zero). Otherwise,
            # e.g. 'y * x == 0', dividing would drop the 'y == 0' solutions.
            can_divide = other.is_zero is False or rhs.is_zero is False

        if can_divide:
            lhs = lhs / other
            rhs = rhs / other

            # An inequality divided by a negative factor changes direction.
            if is_inequality and other.is_negative:
                op = mirror_rel_op(op)  # type: ignore[assignment]

            assert op is not None
            e = op(lhs, rhs)

    ################################################################################
    # left-hand side is FloorDiv
    ################################################################################
    #
    # Given the expression: a // b op c
    # where 'op' is a relational operation, these rules only work if:
    #   - b > 0
    #   - c is an integer
    # Note: the sign of 'b' is generally only known for constants.
    if (
        floordiv_inequality
        and isinstance(e, sympy.Rel)
        and isinstance(e.lhs, FloorDiv)
        and e.lhs.divisor.is_positive
        and e.rhs.is_integer
    ):
        numerator, denominator = e.lhs.args
        quotient = e.rhs

        # Dispatch on the type of the current relation 'e', not of the input
        # 'expr': the division step above may have mirrored the operator
        # (e.g. '-2 * (a // b) > c' became 'a // b < -c / 2').

        # a // b == expr
        # => a >= (b * expr) and a < (b * (expr + 1))
        if isinstance(e, sympy.Eq):
            return sympy.And(
                sympy.Ge(numerator, (quotient * denominator)),  # type: ignore[arg-type]
                sympy.Lt(numerator, ((quotient + 1) * denominator)),  # type: ignore[arg-type]
            )

        # a // b != expr
        # => a < (b * expr) or a >= (b * (expr + 1))
        if isinstance(e, sympy.Ne):
            return sympy.Or(
                sympy.Lt(numerator, (quotient * denominator)),  # type: ignore[arg-type]
                sympy.Ge(numerator, ((quotient + 1) * denominator)),  # type: ignore[arg-type]
            )

        # a // b > expr  => a >= b * (expr + 1)
        # a // b >= expr => a >= b * expr
        if isinstance(e, (sympy.Gt, sympy.Ge)):
            bound = quotient if isinstance(e, sympy.Ge) else (quotient + 1)  # type: ignore[arg-type]
            return sympy.Ge(numerator, (bound * denominator))  # type: ignore[arg-type]

        # a // b < expr  => a < b * expr
        # a // b <= expr => a < b * (expr + 1)
        if isinstance(e, (sympy.Lt, sympy.Le)):
            bound = quotient if isinstance(e, sympy.Lt) else (quotient + 1)  # type: ignore[arg-type]
            return sympy.Lt(numerator, (bound * denominator))  # type: ignore[arg-type]

    return e