mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Automatic fixes that replace certain list comprehensions with generator expressions where appropriate, so that they are consumed immediately. This is preview functionality in ruff for rule C419, and it was applied automatically. Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960 Approved by: https://github.com/malfet
176 lines
6.2 KiB
Python
176 lines
6.2 KiB
Python
import logging
|
|
|
|
from typing import Dict, Optional, Tuple, Type
|
|
|
|
import sympy
|
|
|
|
from torch.utils._sympy.functions import FloorDiv
|
|
|
|
# Module-level logger, used for debug messages when solving is abandoned.
log = logging.getLogger(__name__)

# For each supported relational class, the class obtained by swapping the
# two sides of the relation:
#     a op b  <==>  b mirror(op) a
_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = {
    # (In)equality is symmetric: swapping sides keeps the operator.
    sympy.Eq: sympy.Eq,
    sympy.Ne: sympy.Ne,
    # Ordering relations flip direction when their sides are swapped.
    sympy.Ge: sympy.Le,
    sympy.Gt: sympy.Lt,
    sympy.Le: sympy.Ge,
    sympy.Lt: sympy.Gt,
}

# Relational classes whose direction matters (i.e. everything but Eq and Ne).
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
|
|
|
|
|
|
def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]:
    """Look up the mirrored relational class for 'type'.

    Returns the entry of '_MIRROR_REL_OP' registered for 'type', or 'None'
    when 'type' has no entry in that table.
    """
    # 'dict.get' already falls back to 'None' for missing keys.
    return _MIRROR_REL_OP.get(type)
|
|
|
|
|
|
# Attempts to rewrite the relational expression 'expr' so that 'thing' ends up
# alone on the left-hand side.
#
# On success, returns a tuple of:
#   1. The rewritten relational expression (whose left-hand side is 'thing')
#   2. The right-hand side of that rewritten expression
#
# Returns 'None' whenever 'thing' could not be isolated, i.e. when:
#   - 'expr' is not a relational expression with a known mirror operation
#   - 'thing' occurs on both sides of 'expr'
#   - the rewriting rounds end without the left-hand side being exactly 'thing'
#
# 'trials': upper bound on the number of rewriting rounds ('_try_isolate_lhs'
# calls) applied to each candidate expression.
#
# 'floordiv_inequality': forwarded to '_try_isolate_lhs'; enables the
# conversion of 'FloorDiv' relations into inequalities.
def try_solve(
    expr: sympy.Basic,
    thing: sympy.Basic,
    trials: int = 5,
    floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
    mirror = mirror_rel_op(type(expr))

    # Only relational expressions whose class has a registered mirror are
    # supported; anything else is rejected up-front.
    if mirror is None or not isinstance(expr, sympy.Rel):
        log.debug("expression with unsupported type: %s", type(expr))
        return None

    in_lhs = expr.lhs.has(thing)
    in_rhs = expr.rhs.has(thing)

    # The rewriting below works on the left-hand side only, so 'thing' must
    # live on exactly one side of the relation.
    if in_lhs and in_rhs:
        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
        return None

    # Collect the versions of 'expr' that have 'thing' on their left-hand
    # side. When 'thing' is on the right, swap the sides and mirror the
    # operation:
    #     a < b ==> b > a
    candidates = []
    if in_lhs:
        candidates.append(expr)
    if in_rhs:
        candidates.append(mirror(expr.rhs, expr.lhs))

    for candidate in candidates:
        # Defensive guard: skip entries that are not actual expressions.
        if candidate is None:
            continue

        assert isinstance(candidate, sympy.Rel)

        current = candidate
        for _ in range(trials):
            attempt = _try_isolate_lhs(
                current, thing, floordiv_inequality=floordiv_inequality
            )
            # A round that changes nothing means we reached a fixed point.
            if attempt == current:
                break
            current = attempt  # type: ignore[assignment]

        # Success: 'thing' is the whole left-hand side.
        if isinstance(current, sympy.Rel) and current.lhs == thing:
            return current, current.rhs

    return None
|
|
|
|
|
|
def _try_isolate_lhs(
    expr: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
) -> sympy.Basic:
    """Apply one round of rewrites that move 'expr' towards 'thing op rhs'.

    The round consists of up to three steps, applied in order:
      1. Additive terms of the left-hand side that do not contain 'thing'
         are subtracted from both sides.
      2. Multiplicative factors of the left-hand side that do not contain
         'thing' are divided out of both sides (mirroring the operation when
         an inequality is divided by a negative factor).
      3. If 'floordiv_inequality' is set and the left-hand side is a
         'FloorDiv' with positive divisor compared against an integer, the
         relation is rewritten in terms of the 'FloorDiv' numerator.

    Args:
        expr: expression to rewrite. Anything that is not a 'sympy.Rel' is
            returned unchanged.
        thing: the sub-expression being isolated.
        floordiv_inequality: whether step 3 is enabled.

    Returns:
        The rewritten expression. For 'Eq' and 'Ne' over a 'FloorDiv', this
        is a 'sympy.And' / 'sympy.Or' of two inequalities instead of a
        single relational expression.
    """
    e = expr
    op = type(expr)

    if isinstance(e, sympy.Rel):
        # Move the left-hand side terms that do not contain 'thing' to the
        # right-hand side.
        lhs_not_thing = (
            sum(a for a in e.lhs.args if not a.has(thing))
            if isinstance(e.lhs, sympy.Add)
            else 0
        )
        e = op(expr.lhs - lhs_not_thing, expr.rhs - lhs_not_thing)  # type: ignore[attr-defined]

    # Divide both sides by the factors that don't contain thing.
    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
        lhs, rhs = e.args
        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])

        # If we can't tell whether 'other' is negative or positive, we do nothing.
        # That is because we don't know whether we have to mirror the operation
        # or not.
        if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None):
            # Divide both sides by 'other'.
            lhs = lhs / other
            rhs = rhs / other

            # If 'e' is an inequality and 'other' is negative, we have to
            # mirror the expression.
            if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
                op = mirror_rel_op(op)  # type: ignore[assignment]

            assert op is not None
            e = op(lhs, rhs)

    ################################################################################
    # left-hand side is FloorDiv
    ################################################################################
    #
    # Given the expression: a // b op c
    # where 'op' is a relational operation, these rules only work if:
    #   - b > 0
    #   - c is an integer
    if (
        floordiv_inequality
        and isinstance(e, sympy.Rel)
        and isinstance(e.lhs, FloorDiv)
        and e.lhs.divisor.is_positive
        and e.rhs.is_integer
    ):
        numerator, denominator = e.lhs.args

        # NOTE: the rules below dispatch on the type of 'e', not of 'expr'.
        # The division step above may have mirrored the operation (division by
        # a negative factor), in which case 'expr' still carries the old,
        # un-mirrored operator and would select the rule for the wrong
        # direction, e.g.:
        #   -2 * (a // 3) > 4  ==>  a // 3 < -2  ==>  a < -6

        # a // b == expr
        # => a >= (b * expr) and a < (b * (expr + 1))
        if isinstance(e, sympy.Eq):
            return sympy.And(
                sympy.Ge(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Lt(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # a // b != expr
        # => a < (b * expr) or a >= (b * (expr + 1))
        if isinstance(e, sympy.Ne):
            return sympy.Or(
                sympy.Lt(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Ge(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # The transformations below only work if b is positive.
        # Note: we only have this information for constants.
        # a // b > expr  => a >= b * (expr + 1)
        # a // b >= expr => a >= b * expr
        if isinstance(e, (sympy.Gt, sympy.Ge)):
            quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Ge(numerator, (quotient * denominator))  # type: ignore[arg-type]
        # a // b < expr  => a < b * expr
        # a // b <= expr => a < b * (expr + 1)
        if isinstance(e, (sympy.Lt, sympy.Le)):
            quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Lt(numerator, (quotient * denominator))  # type: ignore[arg-type]

    return e
|