mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc. in sympy; instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper bounds MUST be floats, and likewise for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the "divide out addition by gcd" optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver.
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows us to eventually generate accurate code for Python-semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= `2**53`, beyond what first coercing the integers to floats and then doing true division would give.
* Trunc is split into TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal is updated to consistently return only a float.
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing).

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the result is always a float). This makes them inconsistent with builtins.min/max, but makes it possible to do type analysis without runtime information.

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but a dedicated handler is more convenient for roundtripping in Sympy.
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`.

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2).
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**.
  * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute; instead, everything in the ops handler should map to a single sympy function.
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency to fix tests here.
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality are currently not implemented soundly, so we will lose precision in codegen for large values.
  * TODO: The additions here are not exhaustive yet.
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c`, which is bad because `x / c` is now a rational number.
* `_assert_bound_is_rational` is no more; we no longer generate rational bounds.
* Don't intersect non-int value ranges with the `int_range`.
* Support more sympy Functions for guard SYMPY_INTERP.
* Assert the type of value range is consistent with the variable type.

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions.
* **torch/_inductor/utils.py** - Make sure you actually pass in sympy.Expr to these functions.
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - Don't use infinity to represent int ranges; instead use sys.maxsize - 1.

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
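Two of the int/float rules described above can be observed with plain Python numbers, independent of PyTorch. The sketch below (a) compares Python's correctly rounded `int / int` with dividing after coercing both operands to `float`, which is the precision gap a dedicated `IntTrueDiv`/`int_truediv` is meant to close, and (b) contrasts `builtins.min` with a promoting minimum. `int_truediv_exact`, `int_truediv_via_floats`, and `promoting_min` are hypothetical helpers for illustration only; `promoting_min` imitates the promotion rule described for `torch.sym_min` and is not PyTorch's implementation.

```python
# Plain-Python illustration of the int/float rules discussed above (no torch needed).


def int_truediv_exact(a: int, b: int) -> float:
    # CPython's int / int true division is correctly rounded, even for
    # operands above 2**53 that cannot be represented exactly as floats.
    return a / b


def int_truediv_via_floats(a: int, b: int) -> float:
    # Coercing first rounds `a` to the nearest double before dividing.
    return float(a) / float(b)


def promoting_min(a, b):
    # Hypothetical helper mimicking the promotion rule described for
    # torch.sym_min: if either argument is a float, the result is a float.
    result = min(a, b)
    if isinstance(a, float) or isinstance(b, float):
        return float(result)
    return result


a, b = 2**53 + 1, 3  # a is exactly divisible by 3, but float(a) == 2**53
print(int_truediv_exact(a, b))       # 3002399751580331.0
print(int_truediv_via_floats(a, b))  # 3002399751580330.5

print(type(min(1, 2.0)).__name__)            # int   (builtins: no promotion)
print(type(promoting_min(1, 2.0)).__name__)  # float (promotion semantics)
```

The coerce-first result is off by half a unit because `2**53 + 1` rounds to `2**53` before the division even starts.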
177 lines
6.3 KiB
Python
```python
import logging

from typing import Dict, Optional, Tuple, Type

import sympy

from torch.utils._sympy.functions import FloorDiv

log = logging.getLogger(__name__)

_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = {
    sympy.Eq: sympy.Eq,
    sympy.Ne: sympy.Ne,
    sympy.Ge: sympy.Le,
    sympy.Gt: sympy.Lt,
    sympy.Le: sympy.Ge,
    sympy.Lt: sympy.Gt,
}

INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)


def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]:
    return _MIRROR_REL_OP.get(type, None)
```
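`try_solve` relies on this table to flip a relation so that the side containing `thing` ends up on the left: `mirror(expr.rhs, expr.lhs)` turns `a < b` into `b > a` without changing its meaning. A minimal sketch with plain SymPy follows; `MIRROR` and `swap_sides` are illustrative local names (`MIRROR` is a copy of `_MIRROR_REL_OP`), so the snippet runs without importing `torch`.

```python
import sympy

# Local copy of the mirror table (same mapping as _MIRROR_REL_OP above).
MIRROR = {
    sympy.Eq: sympy.Eq,
    sympy.Ne: sympy.Ne,
    sympy.Ge: sympy.Le,
    sympy.Gt: sympy.Lt,
    sympy.Le: sympy.Ge,
    sympy.Lt: sympy.Gt,
}

x = sympy.Symbol("x", integer=True)


def swap_sides(rel: sympy.Rel) -> sympy.Rel:
    # Swap the operands and use the mirrored operator so the meaning is kept.
    return MIRROR[type(rel)](rel.rhs, rel.lhs)


print(swap_sides(sympy.Lt(3, x)))   # x > 3
print(swap_sides(sympy.Ge(10, x)))  # x <= 10
```

Equality and disequality map to themselves, since `a == b` and `b == a` (likewise `!=`) are the same statement.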
```python
# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.
#
# Returns a tuple of:
#   1. The simplified expression
#   2. The expression on the right-hand side
#
# Returns 'None' if it can't reach a state where the only thing in the
# left-hand side is 'thing'.
#
# 'trials': maximum number of times 'try_solve' will try to isolate 'thing' on
# the left-hand side.
#
# 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into
# inequalities.
def try_solve(
    expr: sympy.Basic,
    thing: sympy.Basic,
    trials: int = 5,
    floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
    mirror = mirror_rel_op(type(expr))

    # Ignore unsupported expressions:
    #   - Those that are not relational operations
    #   - Those that don't have a mirror (just avoiding unexpected classes)
    if not isinstance(expr, sympy.Rel) or mirror is None:
        log.debug("expression with unsupported type: %s", type(expr))
        return None

    lhs_has_thing = expr.lhs.has(thing)
    rhs_has_thing = expr.rhs.has(thing)

    # Give up when 'thing' appears on both sides of the relational expression.
    # That is because, as is, we assume the thing we are trying to isolate is
    # only on one side.
    if lhs_has_thing and rhs_has_thing:
        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
        return None

    # Try considering both LHS and RHS by mirroring the original expression:
    # a < b ==> b > a
    expressions = []

    # Add each version of 'expr' if 'thing' is in its left-hand side.
    if lhs_has_thing:
        expressions.append(expr)
    if rhs_has_thing:
        expressions.append(mirror(expr.rhs, expr.lhs))

    for e in expressions:
        if e is None:
            continue

        assert isinstance(e, sympy.Rel)

        for _ in range(trials):
            trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality)
            # Stop if there was no change in this trial.
            if trial == e:
                break
            e = trial  # type: ignore[assignment]

        # Return if we were able to isolate 'thing' on the left-hand side.
        if isinstance(e, sympy.Rel) and e.lhs == thing:
            log.debug("solved: %s ---> %s", expr, e)
            return e, e.rhs

    return None
```
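The trial loop above repeatedly applies `_try_isolate_lhs` until the relation stops changing. The following is a simplified, self-contained sketch of that fixed-point approach for linear relations. `isolate_step` and `solve_for` are hypothetical names, and the step only covers the "move constants" and "divide by a factor of known sign" rules (no `FloorDiv` handling). The second usage line shows that dividing by a constant factor can leave a rational on the right-hand side.

```python
import sympy

# Local copies of the tables above so this sketch runs without torch.
MIRROR = {
    sympy.Eq: sympy.Eq, sympy.Ne: sympy.Ne,
    sympy.Ge: sympy.Le, sympy.Gt: sympy.Lt,
    sympy.Le: sympy.Ge, sympy.Lt: sympy.Gt,
}
INEQUALITIES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)


def isolate_step(rel, thing):
    # Hypothetical, simplified version of one _try_isolate_lhs step.
    op, lhs, rhs = type(rel), rel.lhs, rel.rhs
    # Move additive terms that do not contain 'thing' to the right-hand side.
    if isinstance(lhs, sympy.Add):
        const = sympy.Add(*[a for a in lhs.args if not a.has(thing)])
        lhs, rhs = lhs - const, rhs - const
    # Divide by factors that do not contain 'thing'; a negative factor mirrors
    # an inequality, and an unknown sign skips the division entirely.
    if isinstance(lhs, sympy.Mul):
        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])
        if not (isinstance(rel, INEQUALITIES) and other.is_negative is None):
            lhs, rhs = lhs / other, rhs / other
            if isinstance(rel, INEQUALITIES) and other.is_negative:
                op = MIRROR[op]
    return op(lhs, rhs)


def solve_for(rel, thing, trials=5):
    # Hypothetical driver in the spirit of try_solve.
    if rel.rhs.has(thing) and not rel.lhs.has(thing):
        rel = MIRROR[type(rel)](rel.rhs, rel.lhs)  # a < b ==> b > a
    for _ in range(trials):
        new = isolate_step(rel, thing)
        if new == rel:  # fixed point reached: no further progress
            break
        rel = new
    if isinstance(rel, sympy.Rel) and rel.lhs == thing:
        return rel, rel.rhs
    return None


x = sympy.Symbol("x", integer=True)
y = sympy.Symbol("y")  # no assumptions, so its sign is unknown
print(solve_for(sympy.Lt(-2 * x + 3, 11), x))  # (x > -4, -4)
print(solve_for(sympy.Le(2 * x + 1, 6), x))    # (x <= 5/2, 5/2)
print(solve_for(sympy.Lt(y * x, 4), x))        # None
```

Stopping at the first unchanged step keeps the cost bounded by `trials` even when no rule applies, which mirrors the `break` in `try_solve`.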
```python
def _try_isolate_lhs(
    expr: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
) -> sympy.Basic:
    e = expr
    op = type(expr)

    if isinstance(e, sympy.Rel):
        # Move any constants in the left-hand side to the right-hand side.
        lhs_not_thing = (
            sum(a for a in e.lhs.args if not a.has(thing))
            if isinstance(e.lhs, sympy.Add)
            else 0
        )
        e = op(expr.lhs - lhs_not_thing, expr.rhs - lhs_not_thing)  # type: ignore[attr-defined]

    # Divide both sides by the factors that don't contain thing.
    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
        lhs, rhs = e.args
        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])

        # If we can't tell whether 'other' is negative or positive, we do nothing.
        # That is because we don't know whether we have to mirror the operation or not.
        if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None):
            # Divide both sides by 'other'.
            lhs = lhs / other
            rhs = rhs / other

            # If 'e' is an inequality and 'other' is negative, we have to
            # mirror the expression.
            if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
                op = mirror_rel_op(op)  # type: ignore[assignment]

            assert op is not None
            e = op(lhs, rhs)

    ################################################################################
    # left-hand side is FloorDiv
    ################################################################################
    #
    # Given the expression: a // b op c
    # where 'op' is a relational operation, these rules only work if:
    #   - b > 0
    #   - c is an integer
    #
    # The rules dispatch on the type of 'e', not 'expr': dividing by a negative
    # factor above may already have mirrored the operator.
    if (
        floordiv_inequality
        and isinstance(e, sympy.Rel)
        and isinstance(e.lhs, FloorDiv)
        and e.lhs.divisor.is_positive
        and e.rhs.is_integer
    ):
        # a // b == c
        # => a >= (b * c) and a < (b * (c + 1))
        if isinstance(e, sympy.Eq):
            numerator, denominator = e.lhs.args
            return sympy.And(
                sympy.Ge(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Lt(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # a // b != c
        # => a < (b * c) or a >= (b * (c + 1))
        if isinstance(e, sympy.Ne):
            numerator, denominator = e.lhs.args
            return sympy.Or(
                sympy.Lt(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Ge(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # The transformations below only work if b is positive.
        # Note: we only have this information for constants.
        # a // b > c  => a >= b * (c + 1)
        # a // b >= c => a >= b * c
        if isinstance(e, (sympy.Gt, sympy.Ge)):
            quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1]))  # type: ignore[arg-type]
        # a // b < c  => a < b * c
        # a // b <= c => a < b * (c + 1)
        if isinstance(e, (sympy.Lt, sympy.Le)):
            quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1]))  # type: ignore[arg-type]

    return e
```
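The `FloorDiv` rewrites above assume a positive divisor and an integer right-hand side. Because Python's `//` on `int` is floor division, the rules can be sanity-checked numerically; the sketch below does so by brute force over arbitrary small ranges, so it is an empirical check, not a proof. `RULES`, `check_floordiv_rules`, and `mirrored_rule_holds` are illustrative names. The last function shows that when a negative factor has been divided out first, the rule must be picked from the mirrored operator.

```python
import itertools

# Each entry: (relation on q = a // b, proposed rewrite as a relation on a).
RULES = {
    "==": (lambda q, c: q == c, lambda a, b, c: b * c <= a < b * (c + 1)),
    "!=": (lambda q, c: q != c, lambda a, b, c: a < b * c or a >= b * (c + 1)),
    ">": (lambda q, c: q > c, lambda a, b, c: a >= b * (c + 1)),
    ">=": (lambda q, c: q >= c, lambda a, b, c: a >= b * c),
    "<": (lambda q, c: q < c, lambda a, b, c: a < b * c),
    "<=": (lambda q, c: q <= c, lambda a, b, c: a < b * (c + 1)),
}


def check_floordiv_rules(values=range(-20, 21), divisors=range(1, 6), rhs=range(-6, 7)):
    # Brute-force check over sample values only; b > 0 and c is an integer.
    for (name, (original, rewritten)), a, b, c in itertools.product(
        RULES.items(), values, divisors, rhs
    ):
        if original(a // b, c) != rewritten(a, b, c):
            return f"rule {name} fails for a={a}, b={b}, c={c}"
    return "all rules hold"


def mirrored_rule_holds(values=range(-50, 51)):
    # '-2 * (a // 3) < 4' divided by -2 becomes 'a // 3 > -2', so the '>' rule
    # gives a >= 3 * (-2 + 1) == -3 (the '<' rule would wrongly give a < -6).
    return all((-2 * (a // 3) < 4) == (a >= -3) for a in values)


print(check_floordiv_rules())  # all rules hold
print(mirrored_rule_holds())   # True
```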