import sympy
from sympy.core.logic import fuzzy_and, fuzzy_or

__all__ = ["FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv"]


def _gcd_or_none(a, b):
    """
    Return ``sympy.gcd(a, b)``, or ``None`` when SymPy cannot compute it.

    ``sympy.gcd`` raises ``sympy.PolynomialError`` for operands it cannot
    treat as polynomials (e.g. a ``Piecewise`` expression).  The gcd is only
    used here to drive optional simplifications, so such a failure is
    reported as ``None`` and callers simply skip the simplification.
    """
    try:
        return sympy.gcd(a, b)
    except sympy.PolynomialError:
        return None


class FloorDiv(sympy.Function):
    """
    We maintain this so that:
    1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
    2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
    """

    nargs = (2,)
    precedence = 50  # precedence of mul  # noqa: F811

    # Default return type for SymPy assumptions.
    # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
    is_real = True

    @property
    def base(self):
        """The dividend, i.e. the first argument."""
        return self.args[0]

    @property
    def divisor(self):
        """The divisor, i.e. the second argument."""
        return self.args[1]

    def _sympystr(self, printer):
        """Print as ``(base//divisor)``, parenthesizing operands as needed."""
        base = printer.parenthesize(self.base, self.precedence)
        divisor = printer.parenthesize(self.divisor, self.precedence)
        return f"({base}//{divisor})"

    # SymPy assumptions based on argument types.
    def _eval_is_real(self):
        return fuzzy_or([self.base.is_real, self.divisor.is_real])

    def _eval_is_integer(self):
        return fuzzy_and([self.base.is_integer, self.divisor.is_integer])

    # Automatic evaluation.
    # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
    @classmethod
    def eval(cls, base, divisor):
        """
        Simplify ``FloorDiv(base, divisor)`` where possible.

        Returns the simplified expression, or ``None`` (implicitly) to leave
        the call unevaluated.

        Raises:
            TypeError: if an operand is Boolean, or is complex while known
                to be neither integer nor real.
            ZeroDivisionError: if ``divisor`` is known to be zero.
        """

        def check_supported_type(x):
            if (x.is_integer is False and x.is_real is False and x.is_complex) or x.is_Boolean:
                raise TypeError(
                    f"unsupported operand type(s) for //: "
                    f"'{type(base).__name__}' and '{type(divisor).__name__}'"
                    f", expected integer or real")

        check_supported_type(base)
        check_supported_type(divisor)

        # We don't provide the same error message as in Python because SymPy
        # makes it difficult to check the types.
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        if base.is_zero:
            return sympy.S.Zero
        if base.is_integer and divisor == 1:
            return base
        if base.is_real and divisor == 1:
            return sympy.floor(base)
        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            return base // divisor
        if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(divisor, (sympy.Integer, sympy.Float)):
            return sympy.floor(base / divisor)
        if isinstance(base, FloorDiv):
            # (a // b) // c  ->  a // (b * c)
            return FloorDiv(base.args[0], base.args[1] * divisor)

        if isinstance(base, sympy.Add):
            # Pull a term that is a multiple of the divisor out of the division.
            for a in base.args:
                gcd = _gcd_or_none(a, divisor)
                if gcd is not None and gcd == divisor:
                    return FloorDiv(base - a, divisor) + a / gcd

        # Cancel a common factor; skipped when the gcd cannot be computed.
        gcd = _gcd_or_none(base, divisor)
        if gcd is not None and gcd != 1:
            return FloorDiv(
                sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
            )


class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c
    """

    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        """
        Simplify ``ModularIndexing(base, divisor, modulus)`` where possible.

        Returns the simplified expression, or ``None`` (implicitly) to leave
        the call unevaluated.
        """
        if base == 0 or modulus == 1:
            return sympy.Integer(0)

        if (
            isinstance(base, sympy.Integer)
            and isinstance(divisor, sympy.Integer)
            and isinstance(modulus, sympy.Integer)
        ):
            return (base // divisor) % modulus

        if divisor != 1:
            # Cancel a common factor; skipped when the gcd cannot be computed.
            gcd = _gcd_or_none(base, divisor)
            if gcd is not None and gcd != 1:
                return ModularIndexing(
                    sympy.simplify(base / gcd), sympy.simplify(divisor / gcd), modulus
                )

        if isinstance(base, sympy.Add):
            # Drop terms that are multiples of modulus * divisor.
            new_terms = []
            all_positive = True
            for term in base.args:
                # A gcd that cannot be computed (None) keeps the term.
                if _gcd_or_none(term, modulus * divisor) != modulus * divisor:
                    if (isinstance(term, sympy.Integer) and term < 0) or (
                        isinstance(term, sympy.Mul)
                        and isinstance(term.args[0], sympy.Integer)
                        and term.args[0] < 0
                    ):
                        # workaround for https://github.com/openai/triton/issues/619,
                        # if there are negative terms, // produces wrong result
                        # TODO if https://github.com/openai/triton/issues/619 is fixed
                        # this optimization would become valid
                        all_positive = False
                        break
                    else:
                        new_terms.append(term)

            if len(new_terms) != len(base.args) and all_positive:
                return ModularIndexing(sum(new_terms), divisor, modulus)

        if isinstance(base, FloorDiv):
            # ((a // b) // c) % m  ->  (a // (b * c)) % m
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)


class CleanDiv(FloorDiv):
    """
    Div where we can assume no rounding.
    This is to enable future optimizations.
    """


class CeilDiv(sympy.Function):
    """
    Div used in indexing that rounds up.
    """

    is_integer = True

    def __new__(cls, base, divisor):
        """
        Build the rounded-up quotient of ``base`` and ``divisor``.

        Returns ``CleanDiv(base, divisor)`` when the gcd of the operands
        equals ``divisor``, otherwise
        ``FloorDiv(base + (divisor - 1), divisor)``.  A gcd that cannot be
        computed takes the second form.
        """
        gcd = _gcd_or_none(base, divisor)
        if gcd is not None and gcd == divisor:
            return CleanDiv(base, divisor)
        else:
            return FloorDiv(base + (divisor - 1), divisor)