Complete revamp of float/promotion sympy handling (#126905)

At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that were previously polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition, multiplication, etc. in sympy; instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
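
As a rough illustration of the strict typing rule above, here is a minimal sketch in plain Python. `StrictRange` is a hypothetical stand-in written for this note; it is not the real `ValueRanges` class and only mirrors the "both bounds share one numeric type" invariant.

```python
from dataclasses import dataclass
from typing import Union

Number = Union[int, float]


@dataclass(frozen=True)
class StrictRange:
    """Hypothetical stand-in for ValueRanges; illustrates the invariant only."""

    lower: Number
    upper: Number

    def __post_init__(self):
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(self.lower, bool) or isinstance(self.upper, bool):
            raise TypeError("bool bounds are not handled in this sketch")
        # Strict typing: an int range has int bounds, a float range has float bounds.
        if type(self.lower) is not type(self.upper):
            raise TypeError(
                f"mixed bounds: {type(self.lower).__name__} and {type(self.upper).__name__}"
            )
        if self.lower > self.upper:
            raise ValueError("lower bound exceeds upper bound")

    @property
    def is_int(self) -> bool:
        return type(self.lower) is int

    @property
    def is_float(self) -> bool:
        return type(self.lower) is float


int_range = StrictRange(2, 10)                  # accepted: both ints
float_range = StrictRange(0.0, float("inf"))    # accepted: both floats
# StrictRange(0, 1.5) would raise TypeError because the bounds are mixed.
```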

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the optimization that divides out the gcd from additions, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver.
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows us to eventually generate accurate code for IntTrueDiv with Python semantics: Python's integer true division is written in a special way to preserve precision when the inputs are >= 2**53, beyond what you get by first coercing the integers to floats and then doing true division.
* Trunc is split into TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal is updated to consistently return a float.
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
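
To make the IntTrueDiv precision point concrete, here is a small plain-Python sketch (no torch or sympy involved). The specific values of `a` and `b` are chosen for this note only; they show one case where Python's `int / int` differs from coercing both operands to float first.

```python
# a is exactly divisible by b, and the true quotient is 2**53 + 1,
# which lies halfway between two adjacent doubles.
a = 3 * (2**53 + 1)
b = 3

# Python int / int: the exact quotient is computed and rounded once.
exact = a / b

# Coerce first: float(a) is already rounded, then the division rounds again.
coerced = float(a) / float(b)

print(exact == coerced)  # False
print(int(exact), int(coerced))
```

The two results differ by 2 here, which is the kind of loss a codegen that lowers `int_truediv` to a plain float division would incur for large inputs.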

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float). This makes them inconsistent with builtins.min/max, but makes it possible to do type analysis without runtime information.
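
The promotion rule for min/max can be sketched in plain Python as follows. `promoting_max` is a hypothetical helper written for this note; it is not `torch.sym_max` and only mimics the "any float argument means a float result" behavior described above.

```python
import builtins


def promoting_max(a, b):
    """Hypothetical sketch of promotion semantics; not torch.sym_max itself."""
    # If either argument is a float, promote both so the result type is
    # known from the argument types alone, without looking at the values.
    if isinstance(a, float) or isinstance(b, float):
        return builtins.max(float(a), float(b))
    return builtins.max(a, b)


builtin_result = builtins.max(3, 2.0)    # returns the int 3: type depends on values
promoted_result = promoting_max(3, 2.0)  # returns the float 3.0: type depends on types
```

With the builtin, the result type depends on which argument happens to be larger at runtime; with promotion, it can be determined statically.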

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
  * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute; instead, everything in ops handler should map to a single sympy function
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality are currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more; we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
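
The last fix follows from the strict typing rule: infinity is a float, so it cannot serve as the bound of an integer range. A minimal sketch of the idea, where `int_upper_bound` is a hypothetical helper for this note and not a function from **torch/export/dynamic_shapes.py**:

```python
import math
import sys


def int_upper_bound(upper=None):
    """Hypothetical helper: choose an int-typed upper bound for an int range."""
    # math.inf is a float, so an int range uses a large finite int instead.
    if upper is None or upper == math.inf:
        return sys.maxsize - 1
    return int(upper)


unbounded = int_upper_bound()     # sys.maxsize - 1, still an int
bounded = int_upper_bound(1024)   # explicit bounds pass through as ints
```

On a 64-bit build, `sys.maxsize - 1` is 9223372036854775806, which matches the `upper=` value that appears in the expected ValueRanges output in the test diffs below.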

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
Edward Z. Yang 2024-06-05 17:22:11 -07:00 committed by PyTorch MergeBot
parent c1a43a69e4
commit 2f7cfecd86
38 changed files with 1675 additions and 670 deletions

@ -49,15 +49,33 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
virtual SymNode mul(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
// NB: legacy, prefer float_truediv or int_truediv
virtual SymNode truediv(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode float_truediv(const SymNode& other) {
return truediv(other);
}
virtual SymNode int_truediv(const SymNode& other) {
return truediv(other);
}
// NB: legacy, prefer float_pow or pow_by_natural
virtual SymNode pow(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode float_pow(const SymNode& other) {
return pow(other);
}
virtual SymNode pow_by_natural(const SymNode& other) {
return pow(other);
}
// NB: legacy, prefer int_floordiv
virtual SymNode floordiv(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode int_floordiv(const SymNode& other) {
return floordiv(other);
}
virtual SymNode mod(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}

@ -78,13 +78,6 @@ for test in tests:
del test
if TEST_Z3:
# this only fails when z3 is available
unittest.expectedFailure(
# SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'.
# Ref: https://github.com/sympy/sympy/issues/25146
DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821
)
if not config.inline_inbuilt_nn_modules:
# TODO model is somehow not being freed when z3 is available
unittest.expectedFailure(

@ -2385,8 +2385,7 @@ def forward(self, x):
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Constraints violated .*!(.*\n)*.*"
"by dim0 = 2\\*dim1(.*\n)*.*"
"Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*",
"Not all values of dim0 .* satisfy the generated guard 4 <= .* and .* <= 10(.*\n)*.*",
):
torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes)

@ -9309,7 +9309,7 @@ ShapeEnv not equal: field values don't match:
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
> Right: {0: 0, 1: 1}
==> var_to_range: values don't match.
> Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
> Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
> Right: {}
==> var_to_sources: values don't match.
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
@ -9343,7 +9343,7 @@ ShapeEnv not equal: field values don't match:
> Left: 2
> Right: 0
==> var_to_range: values don't match.
> Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), u1: ValueRanges(lower=0, upper=1, is_bool=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False)}
> Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False, is_int=True, is_float=False), u1: ValueRanges(lower=0, upper=1, is_bool=False, is_int=True, is_float=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False, is_int=False, is_float=True)}
> Right: {}
""",
)
@ -9420,8 +9420,8 @@ ShapeEnv not equal: field values don't match:
> Left: {s0: 3}
> Right: {}
==> var_to_range: values don't match.
> Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
> Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
> Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
> Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
""",
)
self._replay_and_check(main)
@ -9458,8 +9458,8 @@ ShapeEnv not equal: field values don't match:
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match.
> Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
> Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
> Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
> Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
""",
)
self._replay_and_check(main)
@ -9484,10 +9484,7 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match:
==> deferred_runtime_asserts: values don't match.
> Left: {u0: [Eq(Mod(u0, 3), 0)]}
> Right: {}
==> divisible: values don't match.
> Left: {Mod(u0, 3)}
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
> Right: {}
==> name_to_node: values don't match.
> Left: {_assert, eq, mod, u0}

@ -11,7 +11,12 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Round, RoundDecimal
from torch.utils._sympy.functions import (
FloorDiv,
ModularIndexing,
RoundDecimal,
RoundToInt,
)
class TestIndexingSimplification(InductorTestCase):
@ -168,21 +173,11 @@ class ExprPrinterTests(InductorTestCase):
common_cases = [
# expr, result
# Test exprs.
(
s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1),
lambda c, L: f"((-1{L})*({c}/((-1{L}) + (2{L}*foo)))) + (foo*({c}/((-1{L}) + (2{L}*foo))))",
),
(s1 / (s2 - s3), lambda c, L: f"foo*({c}/(bar + ((-1{L})*baz)))"),
# Test Pow directly.
(
sympy.Pow(s1 + s2, 0),
lambda _, L: f"1{L}",
), # note: simplified before _print_Pow
(
sympy.Pow(s1 + s2, -3),
lambda c, _: f"{c}/((bar + foo)*(bar + foo)*(bar + foo))",
),
]
gpu_cases = common_cases + [
@ -231,12 +226,10 @@ class ExprPrinterTests(InductorTestCase):
self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""")
def test_print_round(self):
expr = Round(sympy.Symbol("x", integer=True) / 2)
expr = RoundToInt(sympy.Symbol("x", integer=True) / 2)
self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""")
self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""")
self.assertExpectedInline(
texpr(expr), """libdevice.llrint((1/2)*x).to(tl.int64)"""
)
self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""")
@parametrize("ndigits", [-1, 0, 1])
def test_print_round_decimal(self, ndigits):
@ -251,45 +244,18 @@ class ExprPrinterTests(InductorTestCase):
f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}",
)
expr = RoundDecimal(sympy.Symbol("x", integer=True), ndigits)
if ndigits >= 0:
for do_print in [pexpr, cexpr, texpr]:
self.assertEqual(do_print(expr), "x")
else:
self.assertEqual(pexpr(expr), f"round(x, {ndigits})")
for do_print in [cexpr, texpr]:
with self.assertRaisesRegex(
ValueError, "only non-negative ndigits are currently supported"
):
do_print(expr)
def test_print_floor_div(self):
for integer in [True, False]:
s1 = sympy.Symbol("s1", integer=integer)
s2 = sympy.Symbol("s2", integer=integer)
expr = FloorDiv(s1, s2)
self.assertEqual(pexpr(expr), "(s1 // s2)")
if integer:
self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)")
else:
self.assertEqual(
cexpr(expr),
"c10::div_floor_floating(static_cast<double>(s1), static_cast<double>(s2))",
)
s1 = sympy.Symbol("s1", integer=True)
s2 = sympy.Symbol("s2", integer=True)
expr = FloorDiv(s1, s2)
self.assertEqual(pexpr(expr), "(s1 // s2)")
self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)")
for integer in [True, False]:
s1 = sympy.Symbol("s1", integer=integer)
s2 = sympy.S(-1)
expr = FloorDiv(s1, s2)
if integer:
self.assertEqual(pexpr(expr), "(-1)*s1")
self.assertEqual(cexpr(expr), "(-1L)*s1")
else:
self.assertEqual(pexpr(expr), "(s1 // (-1))")
self.assertEqual(
cexpr(expr),
"c10::div_floor_floating(static_cast<double>(s1), static_cast<double>((-1L)))",
)
s1 = sympy.Symbol("s1", integer=True)
s2 = sympy.S(-1)
expr = FloorDiv(s1, s2)
self.assertEqual(pexpr(expr), "(-1)*s1")
self.assertEqual(cexpr(expr), "(-1L)*s1")
def test_print_Min_Max(self):
cases = (

@ -3,6 +3,7 @@ import contextlib
import importlib
import math
import operator
import os
import sys
import unittest
@ -649,6 +650,33 @@ class TestInductorDynamic(TestCase):
actual = cfn(5)
self.assertEqual(expect, actual)
def test_interpolate_ceil_eq(self, device):
ceiling = math.ceil
IntTrueDiv = operator.truediv
def fn(t):
s0, s2, s3 = t.size()
x = torch.zeros(
(
s0,
2048,
ceiling(IntTrueDiv(2 * ((s2 - 1) // 8) + 2, 1)),
ceiling(IntTrueDiv(2 * ((s3 - 1) // 8) + 2, 1)),
),
dtype=torch.bfloat16,
)
return torch.nn.functional.interpolate(
x,
scale_factor=2,
mode="nearest",
)
cfn = self.compile_fn(fn)
arg = torch.randn(4, 16, 18)
expect = fn(arg)
actual = cfn(arg)
self.assertEqual(expect, actual)
def test_full_recompiles(self, device):
def fn(x):
_, L = x.shape

@ -158,8 +158,12 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
torch.tensor([operator.sub(x.item(), y.item())]),
torch.tensor([operator.mul(x.item(), y.item())]),
torch.tensor([operator.truediv(x.item(), y.item())]),
torch.tensor([operator.floordiv(x.item(), y.item())]),
torch.tensor([operator.pow(x.item(), y.item())]),
# This requires torch.sym_float, probably easy to lower to
# ONNX but I don't know where to put it
# torch.tensor([operator.floordiv(x.item(), y.item())]),
# NB: abs so that the base and exponent are provably
# non-negative, so we don't generate runtime asserts
torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]),
torch.tensor([operator.abs(x.item())]),
torch.tensor([operator.neg(x.item())]),
torch.tensor([math.ceil(x.item())]),

@ -205,15 +205,15 @@ def create_symtype(cls, pytype, shape_env, val, duck=True):
# TODO: default duck to False
def create_symint(shape_env, i: int, duck=True):
def create_symint(shape_env, i: int, duck=True) -> SymInt:
return create_symtype(SymInt, int, shape_env, i, duck=duck)
def create_symbool(shape_env, b: bool):
def create_symbool(shape_env, b: bool) -> SymBool:
return create_symtype(SymBool, bool, shape_env, b)
def create_symfloat(shape_env, f: float):
def create_symfloat(shape_env, f: float) -> SymFloat:
return create_symtype(SymFloat, float, shape_env, f)
@ -457,14 +457,16 @@ class TestPySymInt(TestCase):
r = sym_int(a1 / 2)
self.assertEqual(guard_int(r), 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""")
self.assertExpectedInline(
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
)
a3 = create_symint(shape_env, 3)
r = sym_int(2.0 * torch.sym_float(a3))
self.assertEqual(guard_int(r), 6)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)"""
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
)
def test_sym_sqrt(self):
@ -474,7 +476,7 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)"""
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)"""
)
def test_sym_floor(self):
@ -483,11 +485,17 @@ class TestPySymInt(TestCase):
r = math.floor(a0 / 2)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""")
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
)
r = math.floor(3.0 * a0)
self.assertEqual(r, 15)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
)
def test_sym_trunc(self):
shape_env = ShapeEnv()
@ -495,12 +503,14 @@ class TestPySymInt(TestCase):
r = math.trunc(a0 / 2)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""")
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
)
r = torch.sym_int(torch.sym_sqrt(a0))
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)"""
str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)"""
)
def test_sym_ceil(self):
@ -510,12 +520,17 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)"""
str(shape_env.guards[0][0]),
"""Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
)
r = math.floor(3.0 * a0)
r1 = 3.0 * a0
r = math.floor(r1)
self.assertEqual(r, 15)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
)
def test_sym_ite(self):
shape_env = ShapeEnv()
@ -962,8 +977,14 @@ class f(torch.nn.Module):
)
class TestSymNumberMagicMethods(TestCase):
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn):
return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn)
def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
# Helper function
# NB: don't use one as that will get specialized
# TODO: We don't have to circuitously create the float, can just
# create a symfloat directly
seed_node = (create_symint(shape_env, 2) / 2.0).node
bool_seed_node = (create_symint(shape_env, 2) == 2).node
@ -976,27 +997,42 @@ class TestSymNumberMagicMethods(TestCase):
else:
return torch.SymFloat(to_node(seed_node, inp))
if fn == "float_pow":
if inp1 < 0:
return
if fn == "pow_by_natural":
if isinstance(inp1, float) or isinstance(inp2, float):
return
if inp2 < 0:
return
def maybe_xfail(inp1, inp2):
if fn == "sym_sqrt" and inp1 < 0:
# ValueError: math domain error
return self.assertRaises((ValueError,))
elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
elif (
fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
and inp2 == 0
):
# ZeroDivisionError: division by zero
return self.assertRaises((ZeroDivisionError,))
elif fn == "pow" and inp1 == 0 and inp2 < 0:
elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0:
# ZeroDivisionError: 0.0 cannot be raised to a negative power
return self.assertRaises((ZeroDivisionError,))
elif (
fn == "pow"
# TODO: dear catastrophe waitress,
# this doesn't work
fn in ["float_pow", "pow_by_natural"]
and inp1 < 0
and inp2 in (2.5, -2.5)
and (
type(inp1) in (SymFloat, SymInt) or type(inp2) in (SymFloat, SymInt)
type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat)
)
and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float))
):
# Complex result, which we do not support:
# TypeError: Cannot convert complex to float
return self.assertRaises((TypeError,))
return self.assertRaises((RuntimeError,))
elif fn in ("lshift", "rshift") and not (
isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
):
@ -1080,6 +1116,9 @@ class TestSymNumberMagicMethods(TestCase):
) and fn in sym_node.only_float_magic_methods:
self.skipTest(f"{fn} is not an int method")
if second_type == "float" and fn in ["mod"]:
self.skipTest(f"{fn} only handles int")
is_unary_fn = fn in sym_node.unary_methods or fn == "round"
# Second argument is ignored for unary function. So only run for one type
if is_unary_fn and second_type == "float":
@ -1251,112 +1290,15 @@ class TestFloorDiv(TestCase):
yield (-x, -y)
def test_floordiv_float_int(self):
values = (
(2.5, 2.1),
(2.1, 2.5),
(2.0, 2.1),
(7, 2.5),
(2.1, 7),
(7, 2),
)
values = ((7, 2),)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
)
def test_floordiv_bool(self):
values = (
(False, True),
(True, 2.5),
(2.5, True),
(False, 7),
(7, True),
)
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
# Compares to int since our FloorDiv has no bool support
self.assertEqual(
TestFloorDiv.python_floordiv(x, y),
TestFloorDiv.torch_floordiv(int(x), int(y)),
)
# Tests that our impl throws
self.assertRaisesRegex(
TypeError,
(
rf"unsupported operand type\(s\) for //: "
rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
rf", expected integer or real"
),
lambda: TestFloorDiv.torch_floordiv(x, y),
)
def test_floordiv_complex(self):
values = (
(1.5 + 2.5j, 1.3 + 3.5j),
(1.5 + 2.5j, 2.5),
(2.5, 1.5 + 2.5j),
(1.5 + 2.5j, 7),
(7, 1.5 + 2.5j),
)
for x, y in TestFloorDiv.yield_test_cases(values):
# We don't test error messages to avoid depending on Python
# interpreter version
self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y))
self.assertRaisesRegex(
TypeError,
(
rf"unsupported operand type\(s\) for //: "
rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
rf", expected integer or real"
),
lambda: TestFloorDiv.torch_floordiv(x, y),
)
def test_floordiv_div_by_zero(self):
values = (
(2.5, 0),
(2.1, 0.0),
(2.3, sympy.Symbol("s", zero=True)),
)
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
# We don't test error messages to avoid depending on Python
# interpreter version
if type(y) is not sympy.Symbol:
self.assertRaises(
ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y)
)
self.assertRaisesRegex(
ZeroDivisionError,
"division by zero",
lambda: TestFloorDiv.torch_floordiv(x, y),
)
def test_floordiv_zero_base(self):
values = (
(0, 2.5),
(0.0, 2.1),
(sympy.Symbol("s", zero=True), 2.3),
)
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
if type(x) is not sympy.Symbol:
self.assertEqual(
TestFloorDiv.python_floordiv(x, y),
TestFloorDiv.torch_floordiv(x, y),
)
else:
self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y))
def test_floordiv_div_by_one(self):
values = (
(2.5, 1),
(2.1, 1.0),
(2, 1.0),
(2, 1),
)
values = ((2, 1),)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(
@ -1367,12 +1309,7 @@ class TestFloorDiv(TestCase):
# Tests how we simplify or evaluate FloorDiv without free variables
shape_env = ShapeEnv()
result = 21
exprs = (
7 * FloorDiv(6, 2),
7 * FloorDiv(6.28, 2),
7 * FloorDiv(6.28, 2.0),
7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))),
)
exprs = (7 * FloorDiv(6, 2),)
for expr in exprs:
self.assertEqual(expr, result)
@ -1382,33 +1319,10 @@ class TestFloorDiv(TestCase):
self.assertEqual(shape_env.simplify(expr), result)
self.assertEqual(shape_env.evaluate_expr(expr), result)
def test_floordiv_simplify_rational(self):
result = 21
a = sympy.Symbol("a", integer=True)
b = sympy.Symbol("b")
cases = [
(FloorDiv(a, sympy.Rational(1, 8)), 8 * a),
(FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)),
]
for expr, expected in cases:
self.assertEqual(expr, expected)
def test_floordiv_assumptions(self):
# We define two Symbols (with different names) for each type to make
# sure the behavior is consistent regardless of whether both arguments
# are the same object or not.
cases = (
sympy.Symbol("i1", integer=True),
sympy.Symbol("i2", integer=True),
sympy.Symbol("r1", real=True),
sympy.Symbol("r2", real=True),
sympy.Symbol("c1", complex=True, real=False, integer=False),
sympy.Symbol("c2", complex=True, real=False, integer=False),
sympy.Symbol("s1"),
sympy.Symbol("s2"),
)
for base, divisor in itertools.product(cases, repeat=2):

@ -1618,7 +1618,8 @@ def forward(self, lengths_1, values_1):
self.assertExpectedInline(r, """\
def forward(self, a_1):
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
pow_1 = sym_size_int ** 0.5; sym_size_int = None
sym_float = torch.sym_float(sym_size_int); sym_size_int = None
pow_1 = sym_float ** 0.5; sym_float = None
div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
return div""")

@ -36,7 +36,12 @@ UNARY_OPS = [
"floor",
"ceil",
]
BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"]
BINARY_OPS = [
"truediv", "floordiv",
# "truncdiv", # TODO
# NB: pow is float_pow
"add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod"
]
UNARY_BOOL_OPS = ["not_"]
BINARY_BOOL_OPS = ["or_", "and_"]
@ -81,16 +86,24 @@ def valid_unary(fn, v):
def valid_binary(fn, a, b):
if fn == "pow" and (
# sympy will expand to x*x*... for integral b; don't do it if it's big
b > 4
or ( # sympy will expand to x*x*... for integral b; don't do it if it's big
a <= 0 and b == -1
)
or (a == b == 0) # no imaginary numbers # 0**0 is undefined
# no imaginary numbers
or a <= 0
# 0**0 is undefined
or (a == b == 0)
):
return False
elif fn == "mod" and b == 0:
elif fn == "pow_by_natural" and (
# sympy will expand to x*x*... for integral b; don't do it if it's big
b > 4
or b < 0
or (a == b == 0)
):
return False
elif (fn == "div" or fn == "truediv") and b == 0:
elif fn == "mod" and (a < 0 or b <= 0):
return False
elif (fn in ["div", "truediv", "floordiv"]) and b == 0:
return False
return True
@ -130,27 +143,26 @@ class TestValueRanges(TestCase):
ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))
@parametrize("fn", BINARY_OPS)
@parametrize("dtype_a", ("int", "float"))
@parametrize("dtype_b", ("int", "float"))
def test_binary_ref(self, fn, dtype_a, dtype_b):
@parametrize("dtype", ("int", "float"))
def test_binary_ref(self, fn, dtype):
to_dtype = {"int": sympy.Integer, "float": sympy.Float}
dtype_a = to_dtype[dtype_a]
dtype_b = to_dtype[dtype_b]
# Don't test float on int only methods
if dtype == "float" and fn in ["pow_by_natural", "mod"]:
return
dtype = to_dtype[dtype]
for a, b in itertools.product(CONSTANTS, repeat=2):
if not valid_binary(fn, a, b):
continue
a = dtype_a(a)
b = dtype_b(b)
a = dtype(a)
b = dtype(b)
with self.subTest(a=a, b=b):
r = getattr(ValueRangeAnalysis, fn)(a, b)
if r == ValueRanges.unknown():
continue
ref_r = getattr(ReferenceAnalysis, fn)(a, b)
# sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf
if fn != "floordiv":
self.assertEqual(r.lower.is_integer, r.upper.is_integer)
self.assertEqual(ref_r.is_integer, r.upper.is_integer)
self.assertEqual(r.lower.is_integer, r.upper.is_integer)
self.assertEqual(ref_r.is_integer, r.upper.is_integer)
self.assertEqual(r.lower, r.upper)
self.assertEqual(ref_r, r.lower)
@ -200,7 +212,8 @@ class TestValueRanges(TestCase):
@parametrize("fn", UNARY_OPS)
def test_unary_ref_range(self, fn):
vals = [-sympy.oo, *CONSTANTS, sympy.oo]
# TODO: bring back sympy.oo testing for float unary fns
vals = CONSTANTS
for a in generate_range(vals):
with self.subTest(a=a):
ref_r = getattr(ValueRangeAnalysis, fn)(a)
@ -216,40 +229,26 @@ class TestValueRanges(TestCase):
# This takes about 4s for all the variants
@parametrize("fn", BINARY_OPS + COMPARE_OPS)
def test_binary_ref_range(self, fn):
vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo]
# TODO: bring back sympy.oo testing for float unary fns
vals = LESS_CONSTANTS
for a, b in itertools.product(generate_range(vals), repeat=2):
# don't attempt pow on exponents that are too large (but oo is OK)
if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
continue
with self.subTest(a=a, b=b):
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
if a0 not in a or b0 not in b:
continue
if not valid_binary(fn, a0, b0):
continue
with self.subTest(a0=a0, b0=b0):
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
r = getattr(ReferenceAnalysis, fn)(
sympy.Integer(a0), sympy.Integer(b0)
)
if r.is_finite:
self.assertIn(r, ref_r)
def test_rational_bounds(self):
# Repro from https://github.com/pytorch/pytorch/issues/105097
from sympy import floor, Eq
shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
new_expr = (
Eq(30 * floor(4 * ((shape_0 + 1) // 96) *
((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 +
2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647),
2880 * floor(((shape_0 + 1) // 96) *
((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 +
323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764)))
new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
self.assertTrue(new_expr.subs({shape_0: 95}))
self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
class TestSympyInterp(TestCase):
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
@ -258,7 +257,13 @@ class TestSympyInterp(TestCase):
if fn in ("div", "truncdiv", "minimum", "maximum", "mod"):
return
from sympy.abc import x, y
is_integer = None
if fn == "pow_by_natural":
is_integer = True
x = sympy.Dummy('x', integer=is_integer)
y = sympy.Dummy('y', integer=is_integer)
vals = CONSTANTS
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
vals = [True, False]
@ -300,29 +305,17 @@ class TestSympyInterp(TestCase):
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
arity = 2
from sympy.abc import x, y
is_integer = None
if fn == "pow_by_natural":
is_integer = True
x = sympy.Dummy('x', integer=is_integer)
y = sympy.Dummy('y', integer=is_integer)
symbols = [x]
if arity == 2:
symbols = [x, y]
# Workaround mpf from symbol error
if fn == "minimum":
sympy_expr = sympy.Min(x, y)
elif fn == "maximum":
sympy_expr = sympy.Max(x, y)
else:
sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
if arity == 1:
def trace_f(px):
return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
else:
def trace_f(px, py):
return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)
gm = fx.symbolic_trace(trace_f)
for args in itertools.product(vals, repeat=arity):
if arity == 1 and not valid_unary(fn, *args):
continue
@ -330,11 +323,28 @@ class TestSympyInterp(TestCase):
continue
if fn == "truncdiv" and args[1] == 0:
continue
elif fn == "pow" and (args[0] == 0 and args[1] <= 0):
elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0):
continue
elif fn == "floordiv" and args[1] == 0:
continue
with self.subTest(args=args):
# Workaround mpf from symbol error
if fn == "minimum":
sympy_expr = sympy.Min(x, y)
elif fn == "maximum":
sympy_expr = sympy.Max(x, y)
else:
sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
if arity == 1:
def trace_f(px):
return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
else:
def trace_f(px, py):
return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)
gm = fx.symbolic_trace(trace_f)
self.assertEqual(
sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr),
gm(*args)

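The range tests in this file share one pattern: compute an output range for an operation, then check that every concrete result computed from in-range inputs falls inside it. Below is a minimal stand-alone sketch of that pattern. `IntRange`, `range_mul`, and `unsound_cases` are hypothetical names for illustration only and are not the real `ValueRanges` / `ValueRangeAnalysis` API.

```python
import itertools
from dataclasses import dataclass


@dataclass(frozen=True)
class IntRange:
    """Hypothetical closed integer interval; a stand-in for ValueRanges."""

    lower: int
    upper: int

    def __contains__(self, value: int) -> bool:
        return self.lower <= value <= self.upper


def range_mul(a: IntRange, b: IntRange) -> IntRange:
    # The extremes of a product over two intervals occur at the endpoints.
    corners = [x * y for x in (a.lower, a.upper) for y in (b.lower, b.upper)]
    return IntRange(min(corners), max(corners))


def unsound_cases(a: IntRange, b: IntRange, constants):
    """Return the (a0, b0) pairs whose concrete product escapes the range."""
    out = range_mul(a, b)
    bad = []
    for a0, b0 in itertools.product(constants, repeat=2):
        if a0 not in a or b0 not in b:
            continue  # only sample points inside the input ranges
        if a0 * b0 not in out:
            bad.append((a0, b0))
    return bad
```

An empty list from `unsound_cases` means no sampled point contradicted the computed range. It is evidence of soundness over the sampled constants, not a proof.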

@ -316,6 +316,75 @@ class SymInt:
# Magic methods installed by torch.fx.experimental.sym_node
def __round__(self, ndigits=None):
return self
def __truediv__(self, other):
if isinstance(other, (builtins.float, SymFloat)):
return sym_float(self).__float_truediv__(other)
if not isinstance(other, (builtins.int, SymInt)):
return NotImplemented
return self.__int_truediv__(other)
def __rtruediv__(self, other):
if isinstance(other, (builtins.float, SymFloat)):
return sym_float(self).__rfloat_truediv__(other)
if not isinstance(other, (builtins.int, SymInt)):
return NotImplemented
return self.__rint_truediv__(other)
def __floordiv__(self, other):
if isinstance(other, (builtins.float, SymFloat)):
return torch.sym_float(math.floor(sym_float(self) / other))
if not isinstance(other, (builtins.int, SymInt)):
return NotImplemented
return self.__int_floordiv__(other)
def __rfloordiv__(self, other):
if isinstance(other, (builtins.float, SymFloat)):
return torch.sym_float(math.floor(other / sym_float(self)))
if not isinstance(other, (builtins.int, SymInt)):
return NotImplemented
return self.__rint_floordiv__(other)
# nb: complex is impossible to handle correctly lol, with
# negative base and integral float need to diverge semantics and
# just always return complex. Neener neener pretend this problem
# doesn't exist
def __pow__(self, other):
if isinstance(other, (builtins.float, SymFloat)):
return sym_float(self).__pow__(other)
if not isinstance(other, (builtins.int, SymInt)):
return NotImplemented
# Guards! This guard is necessary because we need to know it to
# determine the output type of this operation
if other >= 0:
return self.__pow_by_natural__(other)
else:
# Mercifully, when the exponent is negative, Python just promotes
# to doubles and does a float pow:
#
# if (Py_SIZE(b) < 0 && c == NULL) {
# /* if exponent is negative and there's no modulus:
# return a float. This works because we know
# that this calls float_pow() which converts its
# arguments to double. */
# Py_DECREF(a);
# Py_DECREF(b);
# return PyFloat_Type.tp_as_number->nb_power(v, w, x);
# }
return sym_float(self).__pow__(sym_float(other))
def __rpow__(self, other):
if isinstance(other, (builtins.float, SymFloat)):
return sym_float(self).__rpow__(other)
if not isinstance(other, (builtins.int, SymInt)):
return NotImplemented
if self >= 0: # self is exponent
return self.__rpow_by_natural__(other)
else:
return sym_float(self).__rpow__(sym_float(other))
def __eq__(self, other: object) -> builtins.bool:
raise AssertionError("type stub not overridden")
@ -337,6 +406,24 @@ class SymInt:
def __mul__(self, other) -> "SymInt":
raise AssertionError("type stub not overridden")
def __pow_by_natural__(self, other) -> "SymInt":
raise AssertionError("type stub not overridden")
def __rpow_by_natural__(self, other) -> "SymInt":
raise AssertionError("type stub not overridden")
def __int_truediv__(self, other) -> "SymFloat":
raise AssertionError("type stub not overridden")
def __rint_truediv__(self, other) -> "SymFloat":
raise AssertionError("type stub not overridden")
def __int_floordiv__(self, other) -> "SymInt":
raise AssertionError("type stub not overridden")
def __rint_floordiv__(self, other) -> "SymInt":
raise AssertionError("type stub not overridden")
def __sym_max__(self, other):
raise AssertionError("type stub not overridden")
@ -371,9 +458,43 @@ class SymFloat:
# class has a field named node that stores SymNode
self.node = node
def __truediv__(self, other):
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
return NotImplemented
return self.__float_truediv__(sym_float(other))
def __rtruediv__(self, other):
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
return NotImplemented
return self.__rfloat_truediv__(sym_float(other))
def __floordiv__(self, other):
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
return NotImplemented
return torch.sym_float(math.floor(self / sym_float(other)))
def __rfloordiv__(self, other):
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
return NotImplemented
return torch.sym_float(math.floor(sym_float(other) / self))
def __bool__(self):
return self.node.bool_()
# Symbolic power does NOT work with negative base, this is to avoid
# potential complex outputs
def __pow__(self, other):
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
return NotImplemented
torch._check(self >= 0)
return self.__float_pow__(other)
def __rpow__(self, other):
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
return NotImplemented
torch._check(other >= 0)
return self.__rfloat_pow__(other)
# Magic methods installed by torch.fx.experimental.sym_node
def __eq__(self, other: object) -> builtins.bool:
@ -391,6 +512,18 @@ class SymFloat:
def __ge__(self, other) -> builtins.bool:
raise AssertionError("type stub not overridden")
def __float_pow__(self, other) -> "SymFloat":
raise AssertionError("type stub not overridden")
def __rfloat_pow__(self, other) -> "SymFloat":
raise AssertionError("type stub not overridden")
def __float_truediv__(self, other) -> "SymFloat":
raise AssertionError("type stub not overridden")
def __rfloat_truediv__(self, other) -> "SymFloat":
raise AssertionError("type stub not overridden")
def __trunc__(self):
raise AssertionError("type stub not overridden")
@ -524,7 +657,12 @@ def sym_int(a):
return py_int(a) # type: ignore[operator]
def sym_max(a, b):
""" SymInt-aware utility for max()."""
"""
SymInt-aware utility for max which avoids branching on a < b.
Unlike builtins.max(), this only works for int/float, and it always
promotes to float if any argument is float (unlike builtins.max, which
will faithfully preserve the type of the input argument).
"""
from .overrides import has_torch_function, handle_torch_function
if has_torch_function((a, b)):
@ -532,14 +670,19 @@ def sym_max(a, b):
if isinstance(a, (SymInt, SymFloat)):
return a.__sym_max__(b)
elif isinstance(b, (SymInt, SymFloat)):
# NB: If you actually care about preserving output type exactly
# if you do something like max(0, 0.0), it is NOT sound to treat
# min/max as commutative
# Due to promotion semantics, this operator is commutative:
# max(1, 1.0) === max(1.0, 1) === 1.0
return b.__sym_max__(a)
return builtins.max(a, b) # type: ignore[operator]
# TODO: Probably can make bool work too, just lazy
assert isinstance(a, (builtins.int, builtins.float)), type(a)
assert isinstance(b, (builtins.int, builtins.float)), type(b)
if isinstance(a, builtins.float) or isinstance(b, builtins.float):
return builtins.float(builtins.max(a, b))
else:
return builtins.max(a, b)
def sym_min(a, b):
""" SymInt-aware utility for max()."""
""" SymInt-aware utility for min()."""
from .overrides import has_torch_function, handle_torch_function
if has_torch_function((a, b)):
@ -548,7 +691,12 @@ def sym_min(a, b):
return a.__sym_min__(b)
elif isinstance(b, (SymInt, SymFloat)):
return b.__sym_min__(a)
return builtins.min(a, b) # type: ignore[operator]
assert isinstance(a, (builtins.int, builtins.float)), type(a)
assert isinstance(b, (builtins.int, builtins.float)), type(b)
if isinstance(a, builtins.float) or isinstance(b, builtins.float):
return builtins.float(builtins.min(a, b))
else:
return builtins.min(a, b)
# Drop in replacement for math.sqrt, math.sin, math.cos etc
current_module = sys.modules[__name__]

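The promotion rules in the hunks above can be checked with plain Python numbers. The sketch below mirrors only the non-symbolic fallback branch of `sym_max`; `promoting_max` is a hypothetical name, and no `SymInt`/`SymFloat` handling is modeled.

```python
def promoting_max(a, b):
    """max() that returns a float whenever either argument is a float."""
    assert isinstance(a, (int, float)), type(a)
    assert isinstance(b, (int, float)), type(b)
    if isinstance(a, float) or isinstance(b, float):
        return float(max(a, b))
    return max(a, b)


# builtins.max preserves the type of the winning argument; on ties it
# returns the first argument, so the result type depends on argument order.
assert type(max(1, 1.0)) is int
assert type(max(1.0, 1)) is float

# The promoting variant gives the same type either way.
assert type(promoting_max(1, 1.0)) is float
assert type(promoting_max(1.0, 1)) is float
assert type(promoting_max(1, 2)) is int

# Integer pow: the output type depends on the sign of the exponent, which is
# why SymInt.__pow__ has to guard on `other >= 0`.
assert type(2 ** 3) is int
assert type(2 ** -1) is float
```

Because the result type no longer depends on which argument wins, swapping the arguments is safe, which is what the commutativity comment in `sym_max` relies on.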

@ -1474,10 +1474,15 @@ class GraphModuleDeserializer(metaclass=Final):
# Here we force symbols corresponding to SymInts to be at least integers.
# Otherwise some expressions that the shape env would otherwise evaluate to False,
# e.g., 2*s = 9, can have rational solutions, e.g., 9/2.
# TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024)
sym = sym.subs(
{s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols}
)
if isinstance(sym, sympy.Symbol):
# We need to check if the symbol has already been allocated,
# self.symbol_name_to_symbol is not enough because the
# integer-ification of symbols can induce simplification;
# e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral
if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val:
self.symbol_name_to_symbol[val.expr_str] = sym
if hint is not None:
self.shape_env.add_var_to_val(sym, hint)
@ -1496,7 +1501,7 @@ class GraphModuleDeserializer(metaclass=Final):
free_symbols = sym.free_symbols
for s in free_symbols:
if s.name not in self.symbol_name_to_symbol:
self.symbol_name_to_symbol[s.name] = s
self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment]
if vr := self.symbol_name_to_range.get(s.name):
self.shape_env.constrain_symbol_range(
s,


@ -1,3 +1,4 @@
import logging
import operator
from functools import partial
from typing import Any, Callable, Dict
@ -11,6 +12,9 @@ from .utils import cache_on_self, dominated_nodes
from .virtualized import V
log = logging.getLogger(__name__)
class BoundVars:
"""
Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run()
@ -55,6 +59,7 @@ class BoundVars:
with V.set_ops_handler(ValueRangeAnalysis()):
interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
log.debug("get_bounds:\n%s", self.loop_body.root_block.graph)
interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
return self._bounds


@ -340,6 +340,8 @@ class DataTypePropagation:
DataTypePropagation.propagate_loopbody(node._body)
# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(Printer):
@staticmethod
def paren(string):
@ -369,12 +371,6 @@ class ExprPrinter(Printer):
return string
return f"({string})"
def _print_Infinity(self, expr):
return "math.inf"
def _print_NegativeInfinity(self, expr):
return "-math.inf"
def _print_Relational(self, expr):
return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))
@ -384,11 +380,14 @@ class ExprPrinter(Printer):
def _print_Add(self, expr):
return " + ".join(map(self.paren, map(self._print, expr.args)))
# NB: this is OK to put here, because Mod is only defined for positive
# numbers, and so across C/Python its behavior is consistent
def _print_Mod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_FloorDiv(self, expr):
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
def _print_FloatTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
def _print_CleanDiv(self, expr):
return self._print_FloorDiv(expr)
@ -399,10 +398,84 @@ class ExprPrinter(Printer):
# Go figure...
return " >= ".join(map(self.paren, map(self._print, expr.args)))
# NB: The C implementation is injected into codegen at
# torch/_inductor/codegen/wrapper.py
def _print_align(self, expr):
assert len(expr.args) == 1
return f"align({self._print(expr.args[0])})"
# This must be implemented because sympy will collect x * x into Pow(x, 2)
# without any explicit intervention. We print it just like x * x; notably,
# we never generate sympy.Pow with floats.
#
# NB: this is pow by natural. You should never have used builtin sympy.pow
# for FloatPow, and a symbolic exponent should be PowByNatural. This
# means exp is guaranteed to be an integer.
def _print_Pow(self, expr):
base, exp = expr.args
base = self._print(base)
assert exp == int(exp), exp
exp = int(exp)
assert exp >= 0
if exp > 0:
return "*".join([self.paren(base)] * exp)
else: # exp == 0
return "1"
# Explicit NotImplemented functions are to prevent default sympy printing
# behavior, which will just barf out ToFloat(...) to your IR. The error
# message is better here because it tells you which printer class it needs
# to go in.
def _print_ToFloat(self, expr):
raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
def _print_Infinity(self, expr):
raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
def _print_NegativeInfinity(self, expr):
raise NotImplementedError(
f"_print_NegativeInfinity not implemented for {type(self)}"
)
def _print_FloorDiv(self, expr):
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
def _print_PythonMod(self, expr):
raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
def _print_IntTrueDiv(self, expr):
raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
def _print_PowByNatural(self, expr):
raise NotImplementedError(
f"_print_PowByNatural not implemented for {type(self)}"
)
def _print_FloatPow(self, expr):
raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
def _print_TruncToInt(self, expr):
raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
def _print_RoundToInt(self, expr):
raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
def _print_RoundDecimal(self, expr):
raise NotImplementedError(
f"_print_RoundDecimal not implemented for {type(self)}"
)
# NB: Some float operations are INTENTIONALLY not implemented for
# printers. You can implement them as a quick unblock, but it is better
# to ask yourself why we haven't done this computation in the Tensor
# universe instead
def _print_TruncToFloat(self, expr):
raise NotImplementedError(
f"_print_TruncToFloat not implemented for {type(self)}"
)
def doprint(self, expr, *, simplify: bool = True):
# TODO: why are people passing strings to the printer here :think:
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
@ -411,6 +484,10 @@ class ExprPrinter(Printer):
class PythonPrinter(ExprPrinter):
def _print_ToFloat(self, expr):
assert len(expr.args) == 1
return f"float({self._print(expr.args[0])})"
def _print_ModularIndexing(self, expr):
x, div, mod = expr.args
x = self.paren(self.doprint(x))
@ -420,56 +497,72 @@ class PythonPrinter(ExprPrinter):
x = f"({x} // {div})"
return f"{x} % {mod}"
def _print_Infinity(self, expr):
return "math.inf"
def _print_NegativeInfinity(self, expr):
return "-math.inf"
# WARNING: this is dangerous for Triton, which has C-style modulus
def _print_PythonMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
# WARNING: this is dangerous for Triton, which has C-style (truncating) division
def _print_FloorDiv(self, expr):
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
return f"({x} // {div})"
# WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
# does a special algorithm
def _print_IntTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
def _helper_sqrt(self, expr):
return f"math.sqrt({self._print(expr)})"
def _print_OpaqueUnaryFn_sqrt(self, expr):
return self._helper_sqrt(expr.args[0])
def _print_Pow(self, expr):
# Pow() confuses triton
def _print_FloatPow(self, expr):
base, exp = expr.args
# NB: Remember this is sizevar computation! You don't typically
# expect to have to do floating point computation including exponents
# in sizevar compute. Instead of adding support for floating
# point pow, you should make upstream retranslate the Sympy expression
# into Tensor expressions earlier and do that instead.
if exp == 0.5:
return self._helper_sqrt(base)
elif exp == -0.5:
return "1/" + self._helper_sqrt(base)
base = self._print(base)
assert exp == int(exp), exp
exp = int(exp)
if exp > 0:
return "*".join([self.paren(base)] * exp)
elif exp < 0:
return "1/" + self.paren("*".join([self.paren(base)] * abs(exp)))
else: # exp == 0
return "1"
return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
# TODO: Not sure this works with Triton, even when base/exp are integral
def _print_PowByNatural(self, expr):
base, exp = expr.args
return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
def _print_floor(self, expr):
assert len(expr.args) == 1
return f"math.floor({self._print(expr.args[0])})"
def _print_Trunc(self, expr):
def _print_FloorToInt(self, expr):
assert len(expr.args) == 1
return f"math.floor({self._print(expr.args[0])})"
def _print_TruncToInt(self, expr):
assert len(expr.args) == 1
# This also could have been int(), they'll do the same thing for float
return f"math.trunc({self._print(expr.args[0])})"
def _print_ceiling(self, expr):
assert len(expr.args) == 1
return f"math.ceil({self._print(expr.args[0])})"
def _print_CeilToInt(self, expr):
assert len(expr.args) == 1
return f"math.ceil({self._print(expr.args[0])})"
def _print_Abs(self, expr):
assert len(expr.args) == 1
return f"abs({self._print(expr.args[0])})"
# NB: It's expected that we've made explicit any promotion in the sympy
# expression, so it doesn't matter that Python max/min doesn't perform
# promotion
def _print_Max(self, expr):
assert len(expr.args) >= 2
return f"max({', '.join(map(self._print, expr.args))})"
@ -514,7 +607,7 @@ class PythonPrinter(ExprPrinter):
assert len(expr.args) == 1
return f"math.atan({self._print(expr.args[0])})"
def _print_Round(self, expr):
def _print_RoundToInt(self, expr):
assert len(expr.args) == 1
return f"round({self._print(expr.args[0])})"
@ -653,6 +746,29 @@ class OpOverrides:
)
return ops.where(cond, ops.add(r, b), r)
@staticmethod
def trunc_to_int(a, dtype):
return ops.to_dtype(ops.trunc(a), dtype)
@staticmethod
def floor_to_int(a, dtype):
return ops.to_dtype(ops.floor(a), dtype)
@staticmethod
def ceil_to_int(a, dtype):
return ops.to_dtype(ops.ceil(a), dtype)
@staticmethod
def round_to_int(a, dtype):
return ops.to_dtype(ops.round(a), dtype)
@staticmethod
def int_truediv(a, b):
# TODO: this is wrong
# TODO: an easy bandaid is to generate runtime asserts that it's
# <= 2**53, which is when this equation is correct
return ops.truediv(a, b)
@staticmethod
def load_seed(name, offset):
return ops.load(name, sympy.Integer(offset))

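The `int_truediv` TODO above concerns inputs beyond 2**53, where converting each integer to a double before dividing rounds twice. A small sketch in plain Python shows the difference; the input values are chosen for illustration and `compare_truediv` is a hypothetical helper, not part of Inductor.

```python
def compare_truediv(a: int, b: int):
    """Return (Python int/int result, result after converting to float first)."""
    exact = a / b                # int / int: the exact quotient is rounded once
    naive = float(a) / float(b)  # a is rounded first, then the quotient again
    return exact, naive


# Inputs chosen so the exact quotient is 2**53 + 1, which a double cannot
# represent; a itself needs more than 53 bits.
b = 3
a = (2**53 + 1) * b

exact, naive = compare_truediv(a, b)
print(exact)
print(naive)
print(exact == naive)
```

For these inputs the two results differ by one unit in the last place. When both operands fit in 53 bits, the conversions to double are exact and the two computations agree, which is the condition the proposed runtime assert would enforce.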

@ -275,11 +275,11 @@ def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length:
original_index = index
div = sympy.Wild("divisor")
div = sympy.Wild("divisor", integer=True)
if index.has(FloorDiv):
index = index.replace(FloorDiv(var, div), visit_indexing_div)
mod = sympy.Wild("modulus")
mod = sympy.Wild("modulus", integer=True)
if index.has(ModularIndexing):
index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing)


@ -100,10 +100,53 @@ class CppPrinter(ExprPrinter):
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_Trunc(self, expr):
def _print_FloorToInt(self, expr):
assert len(expr.args) == 1
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_TruncToInt(self, expr):
assert len(expr.args) == 1
r = f"std::trunc({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
return f"static_cast<{INDEX_TYPE}>({r})"
def _print_TruncToFloat(self, expr):
assert len(expr.args) == 1
return f"std::trunc({self._print(expr.args[0])})"
def _print_ToFloat(self, expr):
assert len(expr.args) == 1
return f"static_cast<double>({self._print(expr.args[0])})"
# TODO: This is wrong if one of the inputs is negative. This is hard to
# tickle though, as the inputs are typically positive (and if we can prove
# they are positive, we will have used Mod instead, for which this codegen
# is right).
def _print_PythonMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_CMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_IntTrueDiv(self, expr):
lhs, rhs = expr.args
# TODO: This is only accurate up to 2**53
return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
# use std::pow, that operates on floats
def _print_PowByNatural(self, expr):
raise NotImplementedError(
f"_print_PowByNatural not implemented for {type(self)}"
)
def _print_FloatTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
def _print_FloatPow(self, expr):
base, exp = expr.args
return f"std::pow({self._print(base)}, {self._print(exp)})"
def _print_Pow(self, expr):
# Uses float constants to perform FP div
@ -139,6 +182,11 @@ class CppPrinter(ExprPrinter):
r = f"std::ceil({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_CeilToInt(self, expr):
assert len(expr.args) == 1
r = f"std::ceil({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_Min(self, expr):
args = [self._print(a) for a in expr.args]
if len(args) == 2:
@ -200,8 +248,9 @@ class CppPrinter(ExprPrinter):
def _print_OpaqueUnaryFn_sqrt(self, expr):
return f"std::sqrt({self._print(expr.args[0])})"
def _print_Round(self, expr):
def _print_RoundToInt(self, expr):
assert len(expr.args) == 1
# TODO: dispatch to llrint depending on index type
return f"std::lrint({self._print(expr.args[0])})"
def _print_RoundDecimal(self, expr):

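The `*ToInt` printers above differ only in rounding direction. As a reference point, here is how the corresponding Python operations behave on a negative half-way value; `int_conversions` is a hypothetical helper used only for this illustration.

```python
import math


def int_conversions(x: float) -> dict:
    """Python-level counterparts of TruncToInt, FloorToInt, CeilToInt, RoundToInt."""
    return {
        "trunc": math.trunc(x),  # toward zero
        "floor": math.floor(x),  # toward negative infinity
        "ceil": math.ceil(x),    # toward positive infinity
        "round": round(x),       # to nearest, ties to the even neighbour
    }


assert int_conversions(-2.5) == {"trunc": -2, "floor": -3, "ceil": -2, "round": -2}

# Ties go to the even integer, not away from zero.
assert [round(v) for v in (0.5, 1.5, 2.5, 3.5)] == [0, 2, 2, 4]

# round() with an ndigits argument keeps a float input as a float, which is
# the RoundDecimal case; without ndigits it returns an int.
assert type(round(2.5)) is int
assert type(round(2.5, 0)) is float
```

Printing `RoundToInt` as `std::lrint` gives the same tie-breaking as Python's `round` only under the default floating-point rounding mode (round to nearest, ties to even).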

@ -272,23 +272,68 @@ def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]):
return f"{value}[{', '.join(expand)}]"
# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a
# number of operators which Triton "implements", but in a way that is
# inconsistent with Python semantics (and consistent with C semantics). We
# must override all of these, or it is a potential silent correctness problem
class TritonPrinter(PythonPrinter):
def _print_TruncToInt(self, expr):
assert len(expr.args) == 1
return (
f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
)
def _print_ToFloat(self, expr):
assert len(expr.args) == 1
return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)"
# TODO: This is wrong if one of the inputs is negative. This is hard to
# tickle though, as the inputs are typically positive (and if we can prove
# they are positive, we will have used Mod instead, for which this codegen
# is right). If you are trying to hit this, maybe try something like
# torch.arange(n, device="cuda") - 1 and then do a modulus on it
def _print_PythonMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
# TODO: This is wrong, see
# https://github.com/triton-lang/triton/issues/955
# But for Sympy expressions, things will /mostly/ work out because we
# don't usually deal with negative numbers in the division
def _print_FloorDiv(self, expr):
assert expr.is_integer
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
return f"({x} // {div})"
# TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
# precision algorithm, which we would need to replicate here
def _print_IntTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
# NB: sympy.floor/ceiling produce integers, so we have to do the
# conversion to index dtype
def _print_floor(self, expr):
assert len(expr.args) == 1
return (
f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
)
def _print_Trunc(self, expr):
def _print_FloorToInt(self, expr):
assert len(expr.args) == 1
return (
f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
)
def _print_ceiling(self, expr):
assert len(expr.args) == 1
return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
def _print_CeilToInt(self, expr):
assert len(expr.args) == 1
return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
def _helper_sqrt(self, expr):
return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))"
@ -359,20 +404,9 @@ class TritonPrinter(PythonPrinter):
assert len(expr.args) == 1
return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))"
def _print_FloorDiv(self, expr):
if expr.is_integer:
return super()._print_FloorDiv(expr)
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
return f"libdevice.floor({x} / {div}).to({V.kernel.index_dtype})"
def _print_Round(self, expr):
def _print_RoundToInt(self, expr):
assert len(expr.args) == 1
return (
f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
)
return f"libdevice.llrint({self._print(expr.args[0])})"
def _print_RoundDecimal(self, expr):
assert len(expr.args) == 2

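The TODOs on `_print_PythonMod` and `_print_FloorDiv` come down to Python and C disagreeing once an operand is negative. The sketch below emulates C-style integer division and remainder in Python to make the gap concrete; `c_div` and `c_mod` are hypothetical helpers and are not what Triton emits.

```python
def c_div(a: int, b: int) -> int:
    """Integer division truncating toward zero, as in C."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


def c_mod(a: int, b: int) -> int:
    """Remainder matching c_div; the result takes the sign of a."""
    return a - b * c_div(a, b)


# Non-negative operands: both conventions agree.
assert (7 // 3, 7 % 3) == (c_div(7, 3), c_mod(7, 3)) == (2, 1)

# Negative dividend: Python floors and takes the sign of the divisor,
# while C truncates and takes the sign of the dividend.
assert (-7 // 3, -7 % 3) == (-3, 2)
assert (c_div(-7, 3), c_mod(-7, 3)) == (-2, -1)
```

This is why printing `//` and `%` unchanged is only safe when the operands are known to be non-negative.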

@ -1196,8 +1196,11 @@ class GraphLowering(torch.fx.Interpreter):
elif is_magic_method(n.target):
# TODO: this is sus, it probably should be handled in the
# lowerings themselves similarly to sym_size/sym-stride
# https://github.com/pytorch/pytorch/issues/127789
debug("is_magic_method")
if isinstance(n.meta["val"], torch.SymInt):
if isinstance(
n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool)
):
result = n.meta["val"].node.expr
else:
result = super().run_node(n)


@ -44,7 +44,6 @@ from torch._prims_common import (
is_boolean_dtype,
is_float_dtype,
make_channels_last_strides_for,
make_contiguous_strides_for,
StrideType,
)
from torch._subclasses.fake_tensor import get_schema_info
@ -236,7 +235,7 @@ def ir_node_to_tensor(x, guard_shape=True):
if is_storage_and_layout(x):
stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc]
else:
stride = make_contiguous_strides_for(size) # type: ignore[arg-type]
stride = FlexibleLayout.contiguous_strides(size) # type: ignore[arg-type]
dtype = x.get_dtype()
device = x.get_device()
size = convert_shape_to_symint(size)
@ -2766,6 +2765,7 @@ class FlexibleLayout(Layout):
allow_indexing = False
# WARNING! This doesn't handle zero size tensors correctly
@staticmethod
def contiguous_strides(sizes):
if len(sizes) == 0:
@ -5915,7 +5915,7 @@ def _prepare_convolution_fusion_create(
# To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last.
dynamic_shapes = not all(isinstance(i, int) for i in (output_size))
if dynamic_shapes and is_contiguous_storage_and_layout(x):
output_stride = make_contiguous_strides_for(output_size)
output_stride = FlexibleLayout.contiguous_strides(output_size)
else:
output_stride = make_channels_last_strides_for(output_size)
@ -5967,7 +5967,7 @@ def _prepare_linear_fusion_create(
assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
inputs = [x, weight]
output_stride = make_contiguous_strides_for(output_size)
output_stride = FlexibleLayout.contiguous_strides(output_size)
kernel_layout = FixedLayout(
x.get_device(),
x.get_dtype(),
@ -6283,7 +6283,7 @@ class MKLPackedLinear(ExternKernelAlloc):
*m, _ = x.get_size()
oc, _ = orig_w.get_size()
output_size = list(m) + [oc]
output_stride = make_contiguous_strides_for(output_size)
output_stride = FlexibleLayout.contiguous_strides(output_size)
inputs = [x, packed_w, orig_w]
constant_args = [batch_size]
if B is not None:
@ -6601,13 +6601,13 @@ class MkldnnRnnLayer(ExternKernelAlloc):
def get_strides_of_lstm_output(output_shape, batch_first):
assert len(output_shape) == 3, "Expect output_shape to be 3D"
return make_contiguous_strides_for(output_shape)
return FlexibleLayout.contiguous_strides(output_shape)
output_sizes = [output_shape, hy_shape, cy_shape]
output_strides = [
get_strides_of_lstm_output(output_shape, batch_first),
make_contiguous_strides_for(hy_shape),
make_contiguous_strides_for(cy_shape),
FlexibleLayout.contiguous_strides(hy_shape),
FlexibleLayout.contiguous_strides(cy_shape),
]
output_ir = [
MultiOutput(

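For readers unfamiliar with the stride computation being swapped in these hunks, here is a plain-Python sketch of row-major (contiguous) strides. It is not the body of `FlexibleLayout.contiguous_strides`; the `clamp_zero` flag is a hypothetical parameter that illustrates one convention for zero-size dimensions, the case the WARNING comment above is about.

```python
def contiguous_strides(sizes, clamp_zero: bool = False):
    """Row-major strides: the last dimension has stride 1."""
    strides = []
    running = 1
    for size in reversed(sizes):
        strides.append(running)
        # Hypothetical option: treat a zero-size dimension as size 1 so that
        # the strides of the outer dimensions do not collapse to 0.
        running *= max(size, 1) if clamp_zero else size
    return list(reversed(strides))


assert contiguous_strides([2, 3, 4]) == [12, 4, 1]
assert contiguous_strides([]) == []

# A zero-size dimension zeroes out every stride to its left unless clamped.
assert contiguous_strides([2, 0, 3]) == [0, 3, 1]
assert contiguous_strides([2, 0, 3], clamp_zero=True) == [3, 3, 1]
```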

@ -5,7 +5,6 @@ from enum import auto, Enum
from typing import Any, List, Tuple
import torch
from torch._prims_common import make_contiguous_strides_for
from .. import config
from ..ir import (
ComputedBuffer,
@ -389,7 +388,7 @@ def flex_attention(*args, **kwargs):
query.get_device(),
query.get_dtype(),
query.get_size(),
make_contiguous_strides_for(query.get_size()),
FlexibleLayout.contiguous_strides(query.get_size()),
)
# see NOTE:[TritonTemplates with multiple outputs]
logsumexp_shape = query.get_size()[:-1] # [B, H, M]
@ -745,7 +744,7 @@ def flex_attention_backward(*args, **kwargs):
key.get_device(),
key.get_dtype(),
key.get_size(),
make_contiguous_strides_for(key.get_size()),
FlexibleLayout.contiguous_strides(key.get_size()),
)
# Create delta, which is needed for the bwd's kernel


@ -34,7 +34,7 @@ from torch._prims_common import (
Number,
)
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing
from .._dynamo.utils import import_submodule
from . import config, inductor_prims, ir, test_operators # NOQA: F401
@ -4262,7 +4262,7 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim):
out_sz = out_sz[dim]
in_sz = in_sz[dim]
kernel_sz = kernel_sz[dim]
alpha = (in_sz - kernel_sz) / (out_sz - 1)
alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1)
samples_loader = samples.make_loader()
def load(prefix, i):
@ -4372,7 +4372,7 @@ def upsample_nearest2d_backward(
w_kernel_max = ceildiv(inp_w, out_w)
def start_index(index, out_dim, inp_dim):
return CeilDiv(index * inp_dim, out_dim)
return CeilDiv(index * inp_dim, sympy.sympify(out_dim))
def end_index(index, out_dim, inp_dim):
return start_index((index + 1), out_dim, inp_dim)

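The `start_index` / `end_index` helpers above rely on ceiling division over integers. A minimal integer-only sketch follows, assuming a positive divisor; `ceil_div` is a hypothetical stand-in for `CeilDiv`, not its implementation.

```python
def ceil_div(a: int, b: int) -> int:
    """Ceiling of a / b using integer arithmetic only (assumes b > 0)."""
    return -(-a // b)


def start_index(index: int, out_dim: int, inp_dim: int) -> int:
    return ceil_div(index * inp_dim, out_dim)


def end_index(index: int, out_dim: int, inp_dim: int) -> int:
    return start_index(index + 1, out_dim, inp_dim)


assert ceil_div(7, 2) == 4
assert ceil_div(6, 2) == 3

# With out_dim=3 and inp_dim=8, index 1 covers the half-open range [3, 6).
assert (start_index(1, 3, 8), end_index(1, 3, 8)) == (3, 6)
```

Staying in integer arithmetic avoids the float round trip of `math.ceil(a / b)`, which loses exactness for large operands.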

@ -138,6 +138,38 @@ class OpsHandler(Protocol[T]):
"""
...
def trunc_to_int(self, x: T, dtype: torch.dtype) -> T:
"""
Convert x to dtype with truncation semantics (similar to how the int
constructor works in Python). In Inductor codegen, this just decays
to trunc and then to_dtype, but this composite operation helps
roundtrips for Sympy evaluation.
dtype is taken as an explicit parameter because the desired output
dtype is typically the index dtype, which may vary between int32 and
int64 depending on if we've shown that all the indexing operations can
be done in int32.
"""
...
def ceil_to_int(self, x: T, dtype: torch.dtype) -> T:
"""
Convert x to dtype with ceiling semantics. See also trunc_to_int.
"""
...
def floor_to_int(self, x: T, dtype: torch.dtype) -> T:
"""
Convert x to dtype with floor semantics. See also trunc_to_int.
"""
...
def round_to_int(self, x: T, dtype: torch.dtype) -> T:
"""
Convert x to dtype with round-to-even semantics. See also trunc_to_int.
"""
...
def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
"""
Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)
@ -398,21 +430,23 @@ class OpsHandler(Protocol[T]):
def isnan(self, x0: T) -> T:
...
# NB: this returns a float, like the torch operation
# This rounds half to even to break ties
def round(self, x0: T) -> T:
...
# NB: this returns a float, like the torch operation
def floor(self, x0: T) -> T:
...
def sign(self, x0: T) -> T:
...
def to_int(self, x0: T) -> T:
...
# NB: this returns a float, like the torch operation
def trunc(self, x0: T) -> T:
...
# NB: this returns a float, like the torch operation
def ceil(self, x0: T) -> T:
...
@ -449,6 +483,7 @@ class OpsHandler(Protocol[T]):
def mul(self, x0: T, x1: T) -> T:
...
# NB: this returns a float, like the torch operation
def pow(self, x0: T, x1: T) -> T:
...
@ -617,14 +652,21 @@ class OpsHandler(Protocol[T]):
def floordiv(self, x0: T, x1: T) -> T:
"""Python-style floor division between integers only. Computes the
true division of two numbers and floors the result.
true division of two numbers and floors the result. If you want
floor division for floats, do regular truediv and floor the result.
"""
...
def truediv(self, x0: T, x1: T) -> T:
"""True division between floats. Integer inputs are NOT valid: to do
Python style (int, int) -> float division, promote the inputs to float
first."""
"""True division between floats. Integer inputs are NOT valid. To
do Python-style (int, int) -> float division, use int_truediv"""
...
def int_truediv(self, x0: T, x1: T) -> T:
"""True division between integers. This is NOT the same as promoting
to float and doing float division; there is a bespoke algorithm for
doing the division in higher precision than the above.
"""
...
def div(self, x0: T, x1: T) -> T:
@ -640,6 +682,10 @@ class OpsHandler(Protocol[T]):
"""Python-style modulus, take sign from RHS (x1)."""
...
def round_decimal(self, x0: T, x1: T) -> T:
"""Python-style round with decimal argument"""
...
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# In CUDA, optimized implementations of other mathematical operations are
# offered separately via libdevice for double precision computation (in
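
The `int_truediv` docstring in this file separates integer true division from "promote to float, then divide". A minimal plain-Python sketch of why the two can disagree, assuming CPython's `int / int`, which rounds once from the exact quotient. Nothing below is Inductor API; the variable names are illustrative only.

```python
# Plain-Python illustration; none of these names come from Inductor.
a = 3 * (2**53 + 1)  # too wide to be represented exactly as a float64
b = 3

exact = a / b                   # int / int: rounded once, from the exact quotient
promoted = float(a) / float(b)  # rounds a first, then rounds the quotient again

print(exact, promoted)

# Promotion can also fail outright where integer true division succeeds.
huge = 10**400
ratio = huge / 10**399  # fine: the quotient itself fits in a float
try:
    float(huge)
    overflowed = False
except OverflowError:
    overflowed = True
```

The double rounding in `promoted` is the precision loss that a dedicated integer true-division operator is meant to avoid for inputs at or above 2**53.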


@ -386,7 +386,7 @@ class TritonTemplateKernel(TritonKernel):
assert isinstance(mask, (str, type(None)))
assert self.template_mask is None
indices = list(map(TritonPrinter.paren, indices))
index_symbols = [sympy.Symbol(x) for x in indices]
index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
lengths = [
V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
]
@ -410,7 +410,7 @@ class TritonTemplateKernel(TritonKernel):
output_index = self.output_node.get_layout().make_indexer()(index_symbols)
output_index = self.rename_indexing(output_index)
if output_index == contiguous_index:
output_index = sympy.Symbol("xindex")
output_index = sympy.Symbol("xindex", integer=True)
epilogue_args = [val]
for input_node in itertools.chain(


@ -161,9 +161,9 @@ class SizeVarAllocator:
if expr.has(ModularIndexing):
expr = expr.replace(
ModularIndexing(
sympy.Wild("base"),
sympy.Wild("divisor"),
sympy.Wild("modulus"),
sympy.Wild("base", integer=True),
sympy.Wild("divisor", integer=True),
sympy.Wild("modulus", integer=True),
),
visit_modular_indexing,
)
@ -171,8 +171,8 @@ class SizeVarAllocator:
if expr.has(FloorDiv):
expr = expr.replace(
FloorDiv(
sympy.Wild("base"),
sympy.Wild("divisor"),
sympy.Wild("base", integer=True),
sympy.Wild("divisor", integer=True),
),
visit_indexing_div,
)
@ -604,11 +604,11 @@ def _join_dimensions_cached(expr: Expr) -> Expr:
"""
assert isinstance(expr, sympy.Add)
scale = sympy.Wild("scale", exclude=[0])
base = sympy.Wild("base")
divisor = sympy.Wild("divisor")
mod1 = sympy.Wild("modulus")
mod2 = sympy.Wild("modulus2")
scale = sympy.Wild("scale", exclude=[0], integer=True)
base = sympy.Wild("base", integer=True)
divisor = sympy.Wild("divisor", integer=True)
mod1 = sympy.Wild("modulus", integer=True)
mod2 = sympy.Wild("modulus2", integer=True)
for term1 in expr.args:
m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
if m1:


@ -192,7 +192,7 @@ def ceildiv(
numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
) -> Union[int, sympy.Expr]:
if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr):
return CeilDiv(numer, denom)
return CeilDiv(sympy.sympify(numer), sympy.sympify(denom))
# TODO: There is a bug in a call to this function, to repro:
# python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
# --amp --only YituTechConvBert --dynamic-shapes


@ -1727,7 +1727,7 @@ class FakeTensorMode(TorchDispatchMode):
for run_impl_check, op_impl in op_implementations_checks:
if run_impl_check(func):
op_impl_out = op_impl(self, func, *args, **kwargs)
if op_impl_out != NotImplemented:
if op_impl_out is not NotImplemented:
return maybe_propagate_real_tensors(op_impl_out)
def maybe_run_unsafe_fallback(error=None):


@ -1200,8 +1200,13 @@ void initJITBindings(PyObject* module) {
SYMNODE_BINARY(sub)
SYMNODE_BINARY(mul)
SYMNODE_BINARY(truediv)
SYMNODE_BINARY(int_truediv)
SYMNODE_BINARY(float_truediv)
SYMNODE_BINARY(pow)
SYMNODE_BINARY(float_pow)
SYMNODE_BINARY(pow_by_natural)
SYMNODE_BINARY(floordiv)
SYMNODE_BINARY(int_floordiv)
SYMNODE_BINARY(mod)
SYMNODE_BINARY(eq)
SYMNODE_BINARY(ne)


@ -198,14 +198,34 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
return dispatch_common_(__func__, other);
}
c10::SymNode float_truediv(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode int_truediv(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode pow(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode float_pow(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode pow_by_natural(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode floordiv(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode int_floordiv(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode mod(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}


@ -1,7 +1,6 @@
import builtins
import dataclasses
import inspect
import math
import sys
import weakref
from collections import defaultdict
@ -254,11 +253,14 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
shared: Optional[_ConstraintTarget] = None
debug_name: Optional[str] = None
def _clone_with_range(self, lower=0, upper=math.inf):
def _clone_with_range(self, lower=0, upper=None):
# Import sympy locally
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.value_ranges import ValueRanges
if upper is None:
upper = sys.maxsize - 1
constraint_range = StrictMinMaxConstraint(
vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
warn_only=False,
@ -486,7 +488,6 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
)
# Import sympy locally
import sympy
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.value_ranges import ValueRanges
@ -496,7 +497,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
id(t),
index,
StrictMinMaxConstraint(
vr=ValueRanges(lower=0, upper=sympy.oo), warn_only=False
vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False
),
debug_name=debug_name,
)
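
The hunks above replace `math.inf` and `sympy.oo` upper bounds with `sys.maxsize - 1`, in line with the strict typing rule for ValueRanges stated in the PR description. A rough sketch of the mismatch being removed, using a hypothetical helper (`bounds_share_type` is not part of PyTorch) and only the standard library:

```python
import math
import sys


def bounds_share_type(lower, upper):
    """Hypothetical check: True when both bounds are ints or both are floats."""

    def kind(value):
        # bool is a subclass of int, so classify it separately.
        if isinstance(value, bool):
            return bool
        if isinstance(value, int):
            return int
        if isinstance(value, float):
            return float
        raise TypeError(f"unsupported bound type: {type(value).__name__}")

    return kind(lower) is kind(upper)


old_style = bounds_share_type(0, math.inf)         # int lower bound, float upper bound
new_style = bounds_share_type(0, sys.maxsize - 1)  # both bounds are ints
```

With an infinite upper bound the range mixes an int with a float; with `sys.maxsize - 1` both ends are integers.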


@ -277,7 +277,13 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
raise
except Exception:
log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs)
log.error( # noqa: G201
"failed while running %s(*%s, **%s)",
name,
args[1:],
kwargs,
exc_info=log.isEnabledFor(logging.INFO),
)
raise
return wrapper


@ -267,8 +267,11 @@ class SymNode:
def mod(self, other) -> "SymNode":
return self._mod(other) # type: ignore[attr-defined]
def pow(self, other) -> "SymNode":
return self._pow(other) # type: ignore[attr-defined]
def float_pow(self, other) -> "SymNode":
return self._float_pow(other) # type: ignore[attr-defined]
def pow_by_natural(self, other) -> "SymNode":
return self._pow_by_natural(other) # type: ignore[attr-defined]
def and_(self, other) -> "SymNode":
return self._and_(other) # type: ignore[attr-defined]
@ -276,11 +279,14 @@ class SymNode:
def or_(self, other) -> "SymNode":
return self._or_(other) # type: ignore[attr-defined]
def truediv(self, other) -> "SymNode":
return self._truediv(other) # type: ignore[attr-defined]
def float_truediv(self, other) -> "SymNode":
return self._float_truediv(other) # type: ignore[attr-defined]
def floordiv(self, other) -> "SymNode":
return self._floordiv(other) # type: ignore[attr-defined]
def int_truediv(self, other) -> "SymNode":
return self._int_truediv(other) # type: ignore[attr-defined]
def int_floordiv(self, other) -> "SymNode":
return self._int_floordiv(other) # type: ignore[attr-defined]
def lshift(self, other) -> "SymNode":
return self._lshift(other) # type: ignore[attr-defined]
@ -361,6 +367,17 @@ class SymNode:
def sym_and(self, other):
return self.and_(other)
# There is no int_truediv available from C++
def truediv(self, other):
return self.float_truediv(other)
def floordiv(self, other) -> "SymNode":
return self.int_floordiv(other)
# We didn't bind integer pow in C++
def pow(self, other):
return self.float_pow(other)
def is_non_overlapping_and_dense(self, sizes, strides):
return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined]
@ -477,7 +494,7 @@ METHOD_TO_OPERATOR = {
"eq": operator.eq,
"floor": math.floor,
"trunc": math.trunc,
"floordiv": operator.floordiv,
"int_floordiv": operator.floordiv,
"ge": operator.ge,
"gt": operator.gt,
"is_integer": lambda x: x.is_integer(),
@ -489,7 +506,8 @@ METHOD_TO_OPERATOR = {
"ne": operator.ne,
"neg": operator.neg,
"or": operator.or_,
"pow": operator.pow,
"float_pow": operator.pow,
"pow_by_natural": operator.pow,
"round": builtins.round,
"rshift": operator.rshift,
"sub": operator.sub,
@ -498,12 +516,14 @@ METHOD_TO_OPERATOR = {
"sym_max": sym_max,
"sym_min": sym_min,
"sym_not": sym_not,
"truediv": operator.truediv,
"float_truediv": operator.truediv,
"int_truediv": operator.truediv,
}
unary_magic_methods = {
"abs",
"sym_float",
"sym_int",
"ceil",
"floor",
"neg",
@ -559,20 +579,20 @@ also_bool_magic_methods = {"eq"}
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
# Methods that are only for float
only_float_magic_methods = {"is_integer"}
only_float_magic_methods = {"is_integer", "round", "sym_int"}
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
always_float_magic_methods = {"truediv", "sym_float", "pow"}
always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}
for name in math_op_names:
sym_name = f"sym_{name}"
always_float_magic_methods.add(sym_name)
always_int_magic_methods = {"ceil", "floor", "trunc"}
always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
always_bool_magic_methods = {
"eq",
"ne",
@ -590,10 +610,16 @@ always_bool_magic_methods = {
# Methods that have a `__foo__` as well as `__rfoo__`
def _sympy_truediv(a, b):
from torch.utils._sympy.functions import TrueDiv
def _sympy_float_truediv(a, b):
from torch.utils._sympy.functions import FloatTrueDiv
return TrueDiv(a, b)
return FloatTrueDiv(a, b)
def _sympy_int_truediv(a, b):
from torch.utils._sympy.functions import IntTrueDiv
return IntTrueDiv(a, b)
def _sympy_floordiv(a, b):
@ -603,15 +629,24 @@ def _sympy_floordiv(a, b):
def _sympy_mod(a, b):
from torch.utils._sympy.functions import Mod
from torch.utils._sympy.functions import Mod, PythonMod
return Mod(a, b)
if a.is_nonnegative and b.is_nonnegative:
return Mod(a, b)
else:
return PythonMod(a, b)
def _sympy_pow(a, b):
from torch.utils._sympy.functions import Pow
def _sympy_pow_by_natural(a, b):
from torch.utils._sympy.functions import PowByNatural
return Pow(a, b)
return PowByNatural(a, b)
def _sympy_float_pow(a, b):
from torch.utils._sympy.functions import FloatPow
return FloatPow(a, b)
def _sympy_and(a, b):
@ -643,11 +678,13 @@ reflectable_magic_methods = {
"sub": operator.sub,
"mul": operator.mul,
"mod": _sympy_mod,
"pow": _sympy_pow,
"pow_by_natural": _sympy_pow_by_natural,
"float_pow": _sympy_float_pow,
"and": _sympy_and,
"or": _sympy_or,
"truediv": _sympy_truediv,
"floordiv": _sympy_floordiv,
"float_truediv": _sympy_float_truediv,
"int_truediv": _sympy_int_truediv,
"int_floordiv": _sympy_floordiv,
"lshift": _sympy_lshift,
"rshift": _sympy_rshift,
}
@ -672,21 +709,23 @@ def _floor_ceil_helper(a, fn):
def _sympy_floor(a):
import sympy
from torch.utils._sympy.functions import FloorToInt
return _floor_ceil_helper(a, sympy.floor)
return FloorToInt(a)
# NB: this is Python trunc semantics which returns an int. Do NOT use this to
# represent torch.trunc (which is float to float)
def _sympy_trunc(a):
from torch.utils._sympy.functions import Trunc
from torch.utils._sympy.functions import TruncToInt
return Trunc(a)
return TruncToInt(a)
def _sympy_ceil(a):
import sympy
from torch.utils._sympy.functions import CeilToInt
return _floor_ceil_helper(a, sympy.ceiling)
return CeilToInt(a)
def _sympy_eq(a, b):
@ -771,26 +810,28 @@ def _sympy_abs(a):
def _sympy_round(number, ndigits=None):
from torch.utils._sympy.functions import Round, RoundDecimal
from torch.utils._sympy.functions import RoundDecimal, RoundToInt
if ndigits is None:
return Round(number)
return RoundToInt(number)
else:
return RoundDecimal(number, ndigits)
def _sympy_sym_float(a):
# Cannot use sympy.Float(a) here, coz it expects python literals
# Multiply by 1.0 to cast to float. This is needed when the input
# is a SymInt which has the assumption that it is integer and
# SymPy will otherwise assume that return value cannot be a float.
return a * 1.0
from torch.utils._sympy.functions import ToFloat
# NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly
# reports that it is an integer
return ToFloat(a)
def _sympy_is_integer(a):
import sympy
return sympy.Eq(sympy.floor(a), a)
from torch.utils._sympy.functions import ToFloat
return sympy.Eq(ToFloat(sympy.floor(a)), a)
magic_methods = {
@ -989,9 +1030,26 @@ def _make_node_magic(method, func):
self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
)
assert isinstance(other, SymNode)
# TODO: consider constant prop here
try:
out = func(self.expr, other.expr)
if method == "mod":
from torch.utils._sympy.functions import Mod, PythonMod
# Special handling for mod that requires access to the value
# ranges
shape_env = self.shape_env
if (
self.expr.is_nonnegative
or shape_env.bound_sympy(self.expr).lower >= 0
) and (
other.expr.is_nonnegative
or shape_env.bound_sympy(other.expr).lower >= 0
):
out = Mod(self.expr, other.expr)
else:
out = PythonMod(self.expr, other.expr)
else:
# TODO: consider constant prop here
out = func(self.expr, other.expr)
except Exception:
log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
raise
@ -1122,9 +1180,13 @@ def _make_node_magic(method, func):
except Exception:
log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
raise
out = safe_expand(out)
pytype = int if ndigits is None else self.pytype
if ndigits is None:
pytype = int
else:
pytype = self.pytype
out_hint = None
if self.hint is not None:
@ -1136,6 +1198,7 @@ def _make_node_magic(method, func):
# hack down below works, because the round functions down the line all take ndigits=None as default in their
# signature.
# TODO: Remove the args construction below if a different sentinel is used by FX.
# ezyang(May 2024): LOL
args = [self.fx_node]
if ndigits is not None:
args.append(ndigits)
@ -1259,6 +1322,32 @@ def _make_user_magic(method, user_type):
return x.node.is_constant()
return False
# Promotion rules for binary operations. NB: we preserve PYTHON semantics
# - if args are same type, do nothing
# - if one arg is float, promote other arg to float
# - nb: this applies to floordiv, even though output is integral
# (it's still float)
# - pow is funny business
# - if both ints
# - trigger a guard on exponent >= 0
# - if non-negative, output is int
# - otherwise, output is float
# - otherwise, promote other arg to float
# - nb: complex is impossible to handle correctly lol, with
# negative base and integral float need to diverge semantics and
# just always return complex. Neener neener pretend this problem
# doesn't exist
# - equality is pain: Python does the fancy thing where it unpacks the
# mantissa from the float and then compares that against the int.
# Which means it is able to tell that
# 9007199254740993 != 9007199254740992. (rather than if the LHS was
# promoted to float, in which case it would have truncated to the RHS
# and subsequently been equal). We'll model this exactly by having
# special mixed type equality operations. Unfortunately, we need to
# do this for all comparison operations (maybe I'll only implement
# compare)
# - sym_ite mumble mumble really shouldn't allow mixed but whatever
if method in bool_becomes_int_magic_methods:
def promote(x):
@ -1272,6 +1361,41 @@ def _make_user_magic(method, user_type):
def promote(x):
return x
def promote2(self, other):
# TODO: Remove eq and other relations from this list.
# CPython has fancy implementations for these to get as much precision
# as possible instead of just promoting to float64 and praying, so we
# need to handle them specially too.
# Also, note that int_truediv doesn't go through this path: both
# arguments are "int" so there isn't any promotion
if method not in [
"add",
"sub",
"mul",
"mod",
"float_pow",
"float_truediv",
"int_floordiv",
"sym_min",
"sym_max",
# TODO: remove these
"eq",
"ne",
"gt",
"lt",
"le",
"ge",
]:
return self, other
f_self = isinstance(self, (float, torch.SymFloat))
f_other = isinstance(other, (float, torch.SymFloat))
if f_self or f_other:
if not f_self:
self = torch.sym_float(self)
if not f_other:
other = torch.sym_float(other)
return self, other
# Before and after performing the operation, check if any operands are constant.
# If so, extract out the constant values first. If `self` itself is a
# constant, then "redispatch" by calling back into the operator. Sometimes
@ -1286,9 +1410,12 @@ def _make_user_magic(method, user_type):
return wrap_node(getattr(self.node, method_attr)())
def binary_magic_impl(self, other):
if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
return NotImplemented
sym_node_log.debug("MAGIC %s %s %s", method, self, other)
self = promote(self)
other = promote(other)
self, other = promote2(self, other)
if is_constant(self):
return (method_to_operator(method))(get_constant(self), other)
if is_constant(other):
@ -1300,8 +1427,11 @@ def _make_user_magic(method, user_type):
return get_constant(ret) if is_constant(ret) else ret
def rbinary_magic_impl(self, other):
if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
return NotImplemented
self = promote(self)
other = promote(other)
self, other = promote2(self, other)
if is_constant(self):
return (method_to_operator(method))(get_constant(self), other)
if is_constant(other):
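
The promotion-rules comment in `_make_user_magic` refers to several plain-Python behaviours that the symbolic operators try to preserve. A small sketch of those behaviours on ordinary ints and floats (no SymInt or SymFloat involved; the variable names are illustrative):

```python
big = 9007199254740993  # 2**53 + 1, not representable as a float64

# Python compares int and float exactly, without promoting the int first.
mixed_eq = big == 9007199254740992.0
# Promoting first loses the low bit, so the comparison becomes true.
promoted_eq = float(big) == 9007199254740992.0

pow_nonneg = 2 ** 3   # int ** non-negative int stays an int
pow_neg = 2 ** -1     # int ** negative int becomes a float

floordiv_mixed = 7 // 2.0  # one float operand: the result is a float, though integral
```

The difference between `mixed_eq` and `promoted_eq` is the reason the comment says equality cannot simply be handled by promoting to float.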


@ -61,7 +61,7 @@ from torch._logging import trace_structured, structured
from torch import SymBool, SymFloat, SymInt
from torch._guards import ShapeGuard, Source, TracingContext
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator
from torch.utils._sympy.functions import FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt
@ -869,9 +869,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
for N=1.
"""
if min is None:
min = -sympy.oo
min = -sys.maxsize - 1
if max is None:
max = sympy.oo
max = sys.maxsize - 1
if max < min:
raise ValueError(
@ -979,16 +979,6 @@ def eval_guards(gm, *args, ignore_static=True):
def bind_symbols(gm, *args):
return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)
def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges):
"""
We assert that the bounds are either Boolean, or not finite, or can be computed
in exact precision via rational arithmetic.
The only exception to this is the rare case when the user calls `sqrt(s0)`
sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still)
"""
assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr)
assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr)
class DimDynamic(Enum):
"""
Controls how to perform symbol allocation for a dimension. It is always
@ -1387,14 +1377,19 @@ SYMPY_INTERP = {
'Min': min,
'Max': max,
'Mod': operator.mod,
'PythonMod': operator.mod,
'FloorDiv': operator.floordiv,
'TrueDiv': operator.truediv,
'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
'floor': math.floor,
'ceiling': math.ceil,
'FloorToInt': math.floor,
'CeilToInt': math.ceil,
'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless,
'Round': builtins.round,
'RoundToInt': builtins.round,
'RoundDecimal': builtins.round,
'TruncToInt': math.trunc,
'IntTrueDiv': operator.truediv,
}
@ -1642,10 +1637,17 @@ class DimConstraints:
congruence = (base - mod_reduced) % divisor
if congruence != 0:
self._congruences[s].add(congruence)
# NB: Must not be CleanDiv, it needs to be regular sympy division
# so inequality solver works. This is sort of problematic for
# is_integer tests though haha
return (base - mod_reduced) / divisor
if expr.has(Mod):
expr = expr.replace(Mod, mod_handler)
# 7 // -3 is -3, 7 % -3 is -2, and (7 - (-2)) / -3 is -3.0 so negative
# arguments should be OK.
if expr.has(PythonMod):
expr = expr.replace(PythonMod, mod_handler)
if expr.has(FloorDiv):
expr = expr.replace(FloorDiv, floor_div_handler)
return expr
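
The handlers in this hunk rewrite floor division and modulus in terms of regular division, relying on the identity `a // b == (a - a % b) / b` under Python's sign conventions. A quick standard-library check of that identity for mixed signs, as a sketch only (`floordiv_via_mod` is a hypothetical name, not a function from this PR):

```python
import math


def floordiv_via_mod(a, b):
    # Subtract the Python-style remainder, then use regular (true) division.
    return (a - a % b) / b


cases = [(7, 3), (7, -3), (-7, 3), (-7, -3)]
results = [(a // b, a % b, floordiv_via_mod(a, b)) for a, b in cases]

# For contrast: the C-style remainder takes its sign from the dividend.
c_style = math.fmod(7, -3)
```

The identity holds for negative arguments because Python's `%` takes its sign from the divisor; it would not hold with the C-style remainder computed by `math.fmod`.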
@ -3330,6 +3332,7 @@ class ShapeEnv:
self.pending_fresh_unbacked_symbols.append(symbol)
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
vr = self.var_to_range[symbol] = ValueRanges.unknown()
assert vr.is_float
# Create a new FX placeholder and Z3 variable for 'symbol'.
fx_node = self._create_fx_placeholder_and_z3var(symbol, float)
@ -3348,6 +3351,7 @@ class ShapeEnv:
self.counter["create_unbacked_symbol"] += 1
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
vr = self.var_to_range[symbol] = self._default_unspecified_value_range()
assert vr.is_int
# Create a new FX placeholder and Z3 variable for 'symbol'.
fx_node = self._create_fx_placeholder_and_z3var(symbol, int)
@ -3371,6 +3375,7 @@ class ShapeEnv:
self.counter["create_unbacked_symbol"] += 1
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
vr = self.var_to_range[symbol] = ValueRanges(0, 1)
assert vr.is_int
# Create a new FX placeholder and Z3 variable for 'symbol'.
fx_node = self._create_fx_placeholder_and_z3var(symbol, bool)
@ -3516,6 +3521,7 @@ class ShapeEnv:
self.var_to_range[sympy_expr] &= constraint_dim.vr
vr = self.var_to_range[sympy_expr]
assert vr.is_int
if val not in vr:
raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")
@ -3524,6 +3530,7 @@ class ShapeEnv:
elif isinstance(val, float):
self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo)
range_str = f"[{vr.lower}, {vr.upper}]"
assert vr.is_float
else:
# Skip var_range logic for SingletonInt
# Only used for jagged layout nested tensors
@ -3573,6 +3580,7 @@ class ShapeEnv:
def add_var_to_val(self, expr: sympy.Symbol, val: int):
""" Adds a new symbol to the symbolic environment. """
log.debug("add_var_to_val %s %s", expr, val, stack_info=True)
assert expr not in self.var_to_val, f"{expr} already exists"
self.var_to_val[expr] = sympy.Integer(val)
@ -4301,7 +4309,8 @@ class ShapeEnv:
# Clamp values of size-like variables
for x in self.size_like & var_to_range.keys():
if var_to_range[x] is not None:
var_to_range[x] = ValueRanges(2, sympy.oo)
var_to_range[x] = ValueRanges(2, sys.maxsize - 1)
assert var_to_range[x].is_int
return bound_sympy(expr, var_to_range)
@_lru_cache
@ -4418,6 +4427,11 @@ class ShapeEnv:
vr = self._default_unspecified_value_range()
if size_oblivious and k in self.size_like:
lower = max(2, vr.lower)
# This is a bit dodgy: what this means is that there was a
# size-like unbacked symbol whose upper bound < 2. This
# causes... problems.
if lower <= vr.upper:
vr = ValueRanges(lower, vr.upper)
else:
lower = vr.lower
# Don't do anything if we don't have a nontrivial lower bound
@ -4425,10 +4439,17 @@ class ShapeEnv:
# SymInt
if (
lower < (-sys.maxsize - 1) // 2 or
(unbacked_only and k in self.var_to_val)
(unbacked_only and k in self.var_to_val) or
not vr.is_int
):
new_range_env[k] = vr
continue
# The goal is to take our symbols which have various lower bounds
# and reallocate them into new symbols which are exactly positive;
# e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
# [1, inf], where s0 = ess0 + 1. This gives the most information
# to sympy for subsequent simplifications.
#
# Positive means >= 1
# Positive - 1 means >= 0
# Positive + lower - 1 means >= lower
@ -4460,6 +4481,14 @@ class ShapeEnv:
self.counter["sympy_recursion_error"] += 1
return None
new_expr = safe_expand(new_expr)
if new_expr.is_number:
return new_expr
# This is bad to do, the replacement with division leaves us with
# rationals when atom.args[0] is addition, e.g., sympy will happily
# turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication!
"""
floor_div_replace = {}
for atom in new_expr.atoms(FloorDiv):
floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
@ -4468,13 +4497,12 @@ class ShapeEnv:
# are still free symbols
if new_expr.is_number:
return new_expr
"""
# Check if the range can solve it statically
out = bound_sympy(new_expr, new_range_env)
if expect_rational:
_assert_bound_is_rational(new_expr, out)
if out.is_singleton():
return out.lower
if out.is_singleton():
return out.lower
return new_expr if unbacked_only else None
@ -4526,7 +4554,7 @@ class ShapeEnv:
for fd in expr.atoms(FloorDiv):
base, divisor = fd.args
if self.replace(Mod(base, divisor)) in self.divisible:
div_replacements[fd] = base / divisor
div_replacements[fd] = CleanDiv(base, divisor)
new_expr = expr.xreplace(div_replacements)
new_expr = safe_expand(new_expr)
new_pows = new_expr.atoms(sympy.Pow)
@ -4670,7 +4698,10 @@ class ShapeEnv:
int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)
def issubset(x, y):
return (x & int_range).issubset(y & int_range)
if x.is_int and y.is_int:
return (x & int_range).issubset(y & int_range)
else:
return x.issubset(y)
# First, refine the value range of a based on the computed value range
# of tgt. This is always OK to do, even if we decide not to do the
@ -4688,7 +4719,7 @@ class ShapeEnv:
b = next(iter(tgt.free_symbols))
# Try to invert the equality
r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
if r is not None:
if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])):
b_bound = self.bound_sympy(r[1])
self.var_to_range[b] = b_bound & self.var_to_range[b]
tgt_bound = self.bound_sympy(tgt)
@ -4899,12 +4930,12 @@ class ShapeEnv:
):
# We have Mod(i0, q / c) == 0, which means we can
# rewrite i0 as (q / gcd(q, c)) * i1
d = q / sympy.gcd(q, c)
d = q / sympy.gcd(q, c) # TODO: CleanDiv?
i1 = self.create_unbacked_symint().node.expr
# Propagate the value ranges. It doesn't really
# matter if we use truediv or floordiv, because we
# have established divisibility.
self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv(
self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv(
self.var_to_range[i0], ValueRanges.wrap(d)
))
# Propagate size-like-ness
@ -5341,7 +5372,6 @@ class ShapeEnv:
lower, upper = vr.lower, vr.upper
rhs_vr = bound_sympy(rhs, self.var_to_range)
_assert_bound_is_rational(rhs, rhs_vr)
# Let's suppose that we have a preexisting range for x [0, 100].
# Now, we issue a guard x > y, where the range for y is [50, 150].


@ -216,10 +216,7 @@ try:
def abs(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.Abs(number)
def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef:
if ndigits is not None:
raise ValueError("round(..., ndigits=) is currently not supported by shape validations.")
def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef:
# Python's builtin 'round' implements the 'round half to even' strategy
# See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
# z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to
@ -284,7 +281,7 @@ try:
operator.truediv: lift(ops.div),
operator.mod: lift(ops.mod),
operator.abs: lift(ops.abs),
builtins.round: lift(ops.round),
builtins.round: lift(ops.round_to_int),
# Math module.
math.ceil: lift(ops.ceil),
@ -350,6 +347,7 @@ try:
self._ops = _Z3Ops(self._validator)
def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
# TODO: Probably OK to relax this and allow lower precision
if dtype is torch.int64:
return z3.IntVal(int(value))
if dtype is torch.double:
@ -358,6 +356,20 @@ try:
return z3.BoolVal(bool(value))
raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")
def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
if dtype == torch.float64:
return z3.ToReal(x)
raise NotImplementedError(f"to_dtype {dtype} NYI")
def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
return z3.ToInt(x)
def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
return self._ops.round_to_int(x)
def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
return self._ops.div(numerator, denominator)
def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
return self._ops.div(numerator, denominator)
@ -370,11 +382,17 @@ try:
def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
return self._ops.pow(base, exp)
def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
return self._ops.pow(base, exp)
def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
return self._ops.mod(p, q)
def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef:
return self._ops.round(number, ndigits)
def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
return self._ops.ceil(x)
def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
return self._ops.floor(x)
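
The `round_to_int` handler above follows the round-half-to-even rule of Python's builtin `round`. A short standard-library sketch contrasting that rule with two other float-to-int conversions on tie values (illustrative only, not the Z3 encoding used in this file):

```python
import math

halves = [0.5, 1.5, 2.5, -0.5, -1.5]

rounded = [round(x) for x in halves]             # ties go to the even neighbour
half_up = [math.floor(x + 0.5) for x in halves]  # naive "add 0.5 and floor"
truncated = [math.trunc(x) for x in halves]      # int()-style, toward zero
```

The three lists differ on ties, which is why rounding, truncation, and floor each get their own `*_to_int` operator instead of sharing one conversion.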
def __getattr__(self, name: str) -> Any:
REPLACEMENT = {


@ -1,43 +1,78 @@
import functools
import math
import sys
import sympy
from sympy import S
from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or
__all__ = [
"FloorDiv",
"ModularIndexing",
"CleanDiv",
"CeilDiv",
"Pow",
"TrueDiv",
"IntTrueDiv",
"FloatTrueDiv",
"LShift",
"RShift",
"IsNonOverlappingAndDenseIndicator",
"Round",
"RoundToInt",
"RoundDecimal",
"ToFloat",
"FloatPow",
"PowByNatural",
]
def _keep_float(f):
@functools.wraps(f)
def inner(*args):
r = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float
):
r = sympy.Float(float(r))
return r
return inner
def fuzzy_eq(x, y):
if None in (x, y):
return None
return x == y
# It would be nice to have assertions on whether or not inputs is_integer
# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy
# sometimes inconsistently reports floats as integers.
#
# What we can assume from sympy is that if something is an int, it
# definitely is is_integer, but if it is a float it may or may not
# be is_integer. So we are unable to do strong asserts that things
# are NOT integers.
# TODO: In Triton, // rounds to zero, but in Python, it is floor division.
# When we can prove both arguments are non-negative, we should just have a
# GenericFloorDiv (name pending) which can codegen efficiently in Python/C,
# and then PythonFloorDiv and CIntDiv which have the appropriate rounding
# semantics.
#
# Right now, FloorDiv de facto changes behavior depending on whether its
# arguments are negative, which can potentially cause correctness issues.
class FloorDiv(sympy.Function):
"""
We maintain this so that:
1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
NB: This is Python-style floor division, round to -Inf
"""
nargs = (2,)
precedence = 50 # precedence of mul # noqa: F811
# Default return type for SymPy assumptions.
# https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
is_real = True
is_integer = True
@property
def base(self):
@ -52,29 +87,14 @@ class FloorDiv(sympy.Function):
divisor = printer.parenthesize(self.divisor, self.precedence)
return f"({base}//{divisor})"
# SymPy assumptions based on argument types.
def _eval_is_real(self):
return fuzzy_or([self.base.is_real, self.divisor.is_real])
def _eval_is_integer(self):
return fuzzy_and([self.base.is_integer, self.divisor.is_integer])
# Automatic evaluation.
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
@classmethod
def eval(cls, base, divisor):
def check_supported_type(x):
if (
x.is_integer is False and x.is_real is False and x.is_complex
) or x.is_Boolean:
raise TypeError(
f"unsupported operand type(s) for //: "
f"'{type(base).__name__}' and '{type(divisor).__name__}'"
f", expected integer or real"
)
check_supported_type(base)
check_supported_type(divisor)
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
# Assert triggered by inequality solver
# assert base.is_integer, base
# assert divisor.is_integer, divisor
# We don't provide the same error message as in Python because SymPy
# makes it difficult to check the types.
@ -85,26 +105,22 @@ class FloorDiv(sympy.Function):
return sympy.S.Zero
if base.is_integer and divisor == 1:
return base
if base.is_real and divisor == 1:
return sympy.floor(base)
if base.is_integer and divisor == -1:
return sympy.Mul(base, -1)
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
return base // divisor
if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(
divisor, (sympy.Integer, sympy.Float)
):
return sympy.floor(base / divisor)
return sympy.Integer(int(base) // int(divisor))
if isinstance(base, FloorDiv):
return FloorDiv(base.args[0], base.args[1] * divisor)
if isinstance(divisor, sympy.Rational) and divisor.p == 1:
return sympy.floor(base * divisor.q)
# gcd in sympy is over polynomials, so you'll end up with rationals if
# you do this. Don't.
"""
if isinstance(base, sympy.Add):
for a in base.args:
gcd = sympy.gcd(a, divisor)
if gcd == divisor:
return FloorDiv(base - a, divisor) + a / gcd
"""
try:
gcd = sympy.gcd(base, divisor)
@ -189,6 +205,19 @@ class Where(sympy.Function):
nargs = (3,)
def _eval_is_integer(self):
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
def _eval_is_nonnegative(self):
return (
True
if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined]
else None
)
def _eval_is_positive(self):
return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined]
@classmethod
def eval(cls, c, p, q):
if c == sympy.true:
@ -197,28 +226,27 @@ class Where(sympy.Function):
return q
class Mod(sympy.Function):
"""
We maintain this so that we avoid SymPy correctness issues, such as:
https://github.com/sympy/sympy/issues/25146
"""
# Python-style modulus: take sign from RHS
class PythonMod(sympy.Function):
nargs = (2,)
is_integer = True
@classmethod
def eval(cls, p, q):
# This was adapted from: sympy/core/mod.py
# python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
# Triggered by sympy.solvers.inequalities.reduce_inequalities
# assert p.is_integer, p
# assert q.is_integer, q
if q.is_zero:
raise ZeroDivisionError("Modulo by zero")
# If either of them is NaN or infinite.
if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False:
return S.NaN
# Three cases:
# 1. p == 0
# 2. p is either q or -q
# 3. p is integer and q == 1
if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1):
if p is S.Zero or p in (q, -q) or q == 1:
return S.Zero
# Evaluate if they are both literals.
@ -247,10 +275,7 @@ class Mod(sympy.Function):
if sympy.Mod(p, q) == 0:
return S.Zero
def _eval_is_integer(self):
p, q = self.args
return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)]) # type: ignore[attr-defined]
# NB: args[1] for PythonMod
def _eval_is_nonnegative(self):
return True if self.args[1].is_positive else None # type: ignore[attr-defined]
@ -258,6 +283,58 @@ class Mod(sympy.Function):
return True if self.args[1].is_negative else None # type: ignore[attr-defined]
# Generic modulus: only defined on non-negative arguments
class Mod(sympy.Function):
nargs = (2,)
is_integer = True
is_nonnegative = True
@classmethod
def eval(cls, p, q):
# This was adapted from: sympy/core/mod.py
# Triggered by
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
# assert p.is_integer, p
# assert q.is_integer, q
if q.is_zero:
raise ZeroDivisionError("Modulo by zero")
# Three cases:
# 1. p == 0
# 2. p is either q or -q
# 3. q == 1 (p is assumed to be an integer)
if p is S.Zero or p in (q, -q) or q == 1:
return S.Zero
# Evaluate if they are both literals.
if q.is_Number and p.is_Number:
assert p >= 0, p
assert q >= 1, q
return p % q
# If q == 2, it's a matter of whether p is odd or even.
if q.is_Number and q == 2:
if p.is_even:
return S.Zero
if p.is_odd:
return S.One
# If p is a multiple of q.
r = p / q
if r.is_integer:
return S.Zero
# If p < q and its ratio is positive, then:
# - floor(p / q) = 0
# - p % q = p - floor(p / q) * q = p
less = p < q
if less.is_Boolean and bool(less) and r.is_positive:
return p
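The difference between the two modulus operators above can be illustrated in plain Python, without sympy. This is only a sketch: `python_mod` and `c_style_mod` are hypothetical helper names that do not appear in the PR, and the C-style variant (sign taken from the left-hand side) is an assumption about what a truncating backend would compute. On non-negative dividends and positive divisors the two agree, which is the domain the generic `Mod` restricts itself to.

```python
def python_mod(p: int, q: int) -> int:
    # Python-style modulus: the result takes its sign from the divisor q.
    return p % q


def c_style_mod(p: int, q: int) -> int:
    # C-style (truncating) remainder: the result takes its sign from p.
    # Illustrative assumption; not code from the PR.
    r = abs(p) % abs(q)
    return -r if p < 0 else r


# The two conventions disagree once a sign is involved...
negative_case = (python_mod(-7, 3), c_style_mod(-7, 3))

# ...but agree whenever p >= 0 and q >= 1.
agree_on_nonnegative = all(
    python_mod(p, q) == c_style_mod(p, q)
    for p in range(0, 50)
    for q in range(1, 10)
)
```

Because the conventions coincide on that domain, an operator restricted to it can be code-generated the same way for Python and C-like targets.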
class CleanDiv(FloorDiv):
"""
Div where we can assume no rounding.
@ -267,6 +344,36 @@ class CleanDiv(FloorDiv):
pass
# Don't use sympy ceiling/floor as they will attempt simplifications involving
# frac
class CeilToInt(sympy.Function):
is_integer = True
@classmethod
def eval(cls, number):
# assert number.is_integer is not True, number
if number == sympy.oo:
return sympy.Integer(sys.maxsize - 1)
if number == -sympy.oo:
return sympy.Integer(-sys.maxsize - 1)
if isinstance(number, sympy.Number):
return sympy.Integer(math.ceil(float(number)))
class FloorToInt(sympy.Function):
is_integer = True
@classmethod
def eval(cls, number):
# assert number.is_integer is not True, number
if number == sympy.oo:
return sympy.Integer(sys.maxsize - 1)
if number == -sympy.oo:
return sympy.Integer(-sys.maxsize - 1)
if isinstance(number, sympy.Number):
return sympy.Integer(math.floor(float(number)))
class CeilDiv(sympy.Function):
"""
Div used in indexing that rounds up.
@ -275,6 +382,8 @@ class CeilDiv(sympy.Function):
is_integer = True
def __new__(cls, base, divisor):
base = sympy.sympify(base)
divisor = sympy.sympify(divisor)
if sympy.gcd(base, divisor) == divisor:
return CleanDiv(base, divisor)
else:
@ -282,6 +391,8 @@ class CeilDiv(sympy.Function):
class LShift(sympy.Function):
is_integer = True
@classmethod
def eval(cls, base, shift):
if shift < 0:
@ -290,6 +401,8 @@ class LShift(sympy.Function):
class RShift(sympy.Function):
is_integer = True
@classmethod
def eval(cls, base, shift):
if shift < 0:
@ -297,28 +410,107 @@ class RShift(sympy.Function):
return base // 2**shift
# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
class Pow(sympy.Function):
def safe_pow(base, exp):
sign = 1
if base < 0:
base = -base
sign = 1 if exp % 2 == 0 else -1
return sign * _safe_pow(base, exp)
def _safe_pow(base, exponent):
if exponent < 0:
raise ValueError("Exponent must be non-negative.")
if exponent == 0:
return 1
half_exp = safe_pow(base, exponent // 2)
if half_exp > sys.maxsize - 1:
return sys.maxsize - 1
result = half_exp * half_exp
if result > sys.maxsize - 1:
return sys.maxsize - 1
if exponent % 2 == 1:
result *= base
if result > sys.maxsize - 1:
return sys.maxsize - 1
return result
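The saturating behavior of `safe_pow` can be sketched independently of sympy. The version below is a simplified stand-in, not the PR's implementation: `saturating_pow` and `CAP` are hypothetical names, and it uses a plain loop instead of repeated squaring. It only illustrates the idea of clamping the magnitude at `sys.maxsize - 1` while keeping the sign determined by the parity of the exponent.

```python
import sys

# Same clamp value the diff uses; the name CAP is illustrative.
CAP = sys.maxsize - 1


def saturating_pow(base: int, exp: int) -> int:
    """Integer power whose magnitude is clamped to CAP (sketch only)."""
    if exp < 0:
        raise ValueError("Exponent must be non-negative.")
    sign = 1
    if base < 0:
        base = -base
        # A negative base gives a negative result only for odd exponents.
        sign = 1 if exp % 2 == 0 else -1
    result = 1
    for _ in range(exp):
        result *= base
        if result > CAP:
            # Stop early: the magnitude can only grow from here.
            result = CAP
            break
    return sign * result
```

For example, `saturating_pow(10, 100)` yields `sys.maxsize - 1` rather than a 101-digit integer, which keeps range bounds from growing without limit.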
class PowByNatural(sympy.Function):
is_integer = True
@classmethod
def eval(cls, base, exp):
if exp.is_zero:
return sympy.Integer(1)
elif base.is_zero and exp < 0:
raise ZeroDivisionError(f"{base} cannot be raised to a negative power")
else:
return base**exp
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
return sympy.Integer(safe_pow(base, exp))
if isinstance(exp, sympy.Integer):
# Translate power into iterated multiplication
r = sympy.Integer(1)
for _ in range(int(exp)):
r *= base
return r
# NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
# is a natural number if we do
# base is assumed to be nonnegative, thereby preventing complex numbers from
# occurring
class FloatPow(sympy.Function):
is_integer = False
is_real = True
@classmethod
def eval(cls, base, exp):
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
return sympy.Float(float(base) ** float(exp))
# NB: do not do any nontrivial reasoning
# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
class TrueDiv(sympy.Function):
#
# In particular, sympy division is willing to simplify x/x == 1
# where 1 is an integer, but this must be a float if x was float.
class FloatTrueDiv(sympy.Function):
is_integer = False
is_real = True
@classmethod
def eval(cls, base, divisor):
# assert base.is_integer is not True, base
# assert divisor.is_integer is not True, divisor
if divisor.is_zero:
raise ZeroDivisionError("division by zero")
if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
return sympy.Float(float(base) / float(divisor))
# Overloaded to be compatible with regular Python. We distinguish this from
# FloatTrueDiv, because the code generation has to be different for this case:
# Python has a fancy algorithm for integer true division that isn't just
# "promote both arguments to float and use float division", so you need to
# codegen it differently. While technically you can work it out from the
# types of the input, this is often inconvenient to do in Inductor codegen,
# so just have a different operator
# NB: Right now, Inductor codegen doesn't implement this correctly lol
class IntTrueDiv(sympy.Function):
is_integer = False
is_real = True
@classmethod
def eval(cls, base, divisor):
if divisor.is_zero:
raise ZeroDivisionError("division by zero")
else:
return base / divisor
if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
return sympy.Float(int(base) / int(divisor))
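The reason `IntTrueDiv` is kept separate from `FloatTrueDiv` can be seen with plain Python integers. The sketch below picks one hand-worked pair of operands (the specific values are chosen for illustration, not taken from the PR) where Python's integer true division, which rounds the exact quotient once, differs from first converting both operands to float and then dividing, which rounds twice.

```python
# Operands chosen so that the exact quotient is 2**53 + 1, which lies
# exactly halfway between two adjacent doubles.
a = 3 * (2**53 + 1)
b = 3

# Python's int / int computes the correctly rounded quotient directly.
exact_division = a / b

# Converting first rounds `a` to a nearby double before dividing,
# so a second rounding happens on an already perturbed value.
promote_then_divide = float(a) / float(b)

differs = exact_division != promote_then_divide
```

Here the two results differ in the last representable step, which is why code generation that simply promotes to float does not reproduce Python semantics for large inputs.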
# TODO: As an indicator, this != 0 implies == 1 (and vice versa).
@ -353,45 +545,85 @@ class IsNonOverlappingAndDenseIndicator(sympy.Function):
return None
class Trunc(sympy.Function):
# NB: this is inconsistent with math.trunc in Python
class TruncToFloat(sympy.Function):
is_integer = False
is_real = True
@classmethod
def eval(cls, number):
# assert number.is_integer is not True, number
if isinstance(number, sympy.Number):
# NB: It is safe to use truncation to integer, which is what
# math.trunc does, as Python integers are arbitrary precision and
# so we are guaranteed not to lose precision when we do this
return sympy.Float(math.trunc(float(number)))
class TruncToInt(sympy.Function):
is_integer = True
@classmethod
def eval(cls, number):
if number.is_integer:
return number
elif isinstance(number, sympy.Number):
# assert number.is_integer is not True, number
if number == sympy.oo:
return sympy.Integer(sys.maxsize - 1)
if number == -sympy.oo:
return sympy.Integer(-sys.maxsize - 1)
if isinstance(number, sympy.Number):
return sympy.Integer(math.trunc(float(number)))
class Round(sympy.Function):
# This is float -> int
class RoundToInt(sympy.Function):
is_integer = True
@classmethod
def eval(cls, number):
if number.is_integer:
return number
elif isinstance(number, sympy.Number):
return sympy.Integer(round(float(number)))
# assert number.is_integer is not True, number
def __int__(self):
# This will only ever be called when computing size hints. At that point, self.args[0] should be a number and
# no longer an expression. If it were, the float call would fail and the caller would handle this further.
return round(float(self.args[0])) # type: ignore[arg-type]
if isinstance(number, sympy.Float):
return sympy.Integer(round(float(number), 0))
# To get float -> int, Python style round semantics.
#
# x = PyFloat_AsDouble(self);
# if (o_ndigits == Py_None) {
# /* single-argument round or with None ndigits:
# * round to nearest integer */
# rounded = round(x);
# if (fabs(x-rounded) == 0.5)
# /* halfway case: round to even */
# rounded = 2.0*round(x/2.0);
# return PyLong_FromDouble(rounded);
# }
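The C fragment quoted above can be transcribed into Python to check that it reproduces the built-in `round` on halfway cases. This is a sketch under assumptions: `c_round` is a hypothetical stand-in for C's `round` (half away from zero) for finite inputs, and `py_round_to_int` is an illustrative name, not part of the PR.

```python
import math


def c_round(x: float) -> float:
    # Stand-in for C's round(): round half away from zero (finite x only).
    t = float(math.trunc(x))
    if abs(x - t) >= 0.5:
        t += math.copysign(1.0, x)
    return t


def py_round_to_int(x: float) -> int:
    # Transcription of the quoted C logic for single-argument round().
    rounded = c_round(x)
    if abs(x - rounded) == 0.5:
        # Halfway case: round to even.
        rounded = 2.0 * c_round(x / 2.0)
    return int(rounded)


samples = [0.5, 1.5, 2.5, -0.5, -1.5, -2.5, 2.4, 2.6, -2.6]
matches_builtin = all(py_round_to_int(x) == round(x) for x in samples)
```

The halfway adjustment is what distinguishes Python's `round` from a plain half-away-from-zero rounding: `2.5` goes to `2`, not `3`.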
# NB: Like Round, this only ever returns floats. ndigits cannot be None
class RoundDecimal(sympy.Function):
is_integer = False
is_real = True
@classmethod
def eval(cls, number, ndigits):
if number.is_integer and ndigits >= 0:
# assert number.is_integer is not True, number
if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer):
return sympy.Float(round(float(number), int(ndigits)))
class ToFloat(sympy.Function):
is_integer = False
is_real = True
@classmethod
def eval(cls, number):
if number in [sympy.oo, -sympy.oo]:
return number
elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
value_type, output_type = (
(int, sympy.Integer)
if isinstance(number, sympy.Integer)
else (float, sympy.Float)
)
return output_type(round(value_type(number), int(ndigits)))
if isinstance(number, sympy.Integer):
return sympy.Float(int(number))
def make_opaque_unary_fn(name):

@ -15,16 +15,23 @@ from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
import torch
from .functions import (
CeilToInt,
CleanDiv,
FloatPow,
FloatTrueDiv,
FloorDiv,
FloorToInt,
IntTrueDiv,
IsNonOverlappingAndDenseIndicator,
Mod,
ModularIndexing,
Pow,
Round,
PowByNatural,
PythonMod,
RoundDecimal,
TrueDiv,
Trunc,
RoundToInt,
ToFloat,
TruncToFloat,
TruncToInt,
Where,
)
@ -49,30 +56,39 @@ def handlers():
sympy.Le: "le",
sympy.Ge: "ge",
sympy.Not: "not_",
TrueDiv: "truediv",
IntTrueDiv: "int_truediv",
FloatTrueDiv: "truediv",
FloorDiv: "floordiv",
CleanDiv: "div",
Trunc: "trunc",
CleanDiv: "floordiv", # TODO: hmm?
TruncToFloat: "trunc",
Where: "where",
sympy.Add: "add",
sympy.Mul: "mul",
Pow: "pow",
sympy.Pow: "pow",
FloatPow: "pow",
PowByNatural: "pow_by_natural",
# sympy simplifies x * x into Pow(x, 2), so we need to handle this.
# Do NOT use builtin Pow for floats
# TODO: There is a hazard here, if we have float * float it will
# also get turned into Pow(float, 2) but we don't want this because
# pow_by_natural is assumed to only be integers. Probably the fix is
# to add a FloatMul to impede this optimization
sympy.Pow: "pow_by_natural",
Mod: "mod",
PythonMod: "mod", # TODO: this is wrong
# TODO: Inductor can generate these, but it's ill-specified which
# semantics were intended here. Needs to be cleaned up along with
# FloorDiv in a bigger cleanup
sympy.Mod: "mod",
sympy.Abs: "abs",
sympy.log: "log",
sympy.exp: "exp",
sympy.floor: "floor",
sympy.ceiling: "ceil",
sympy.Min: "minimum",
sympy.Max: "maximum",
ModularIndexing: "modular_indexing",
sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
sympy.Piecewise: "piecewise",
IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
Round: "round",
RoundDecimal: "round",
RoundDecimal: "round_decimal",
}
for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
HANDLERS[getattr(sympy, name)] = name
@ -84,7 +100,11 @@ ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}
def sympy_interp(
analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean]
analysis,
env: Dict[sympy.Symbol, Any],
expr: Union[sympy.Expr, SympyBoolean],
*,
index_dtype=torch.int64,
):
# Handle base cases
dtype = None
@ -105,9 +125,32 @@ def sympy_interp(
expr.args[1], sympy.core.numbers.Half
):
return analysis.sqrt(sympy_interp(analysis, env, expr.args[0]))
if isinstance(expr, ToFloat):
return analysis.to_dtype(
sympy_interp(analysis, env, expr.args[0]), torch.float64
)
# Recursive case
args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type]
# These handlers are special because they take an extra dtype argument
# specifying what they should convert to, and we need to appropriately set
# this up when we convert from Sympy. A reasonable default when you
# are translating is to conservatively do int64, and then narrow these
# arguments later when you discover you can narrow the index range. But
# if you already know that 32-bit indexing is OK, you can directly do the
# sympy translation with index_dtype=torch.int32
INDEX_DTYPE_HANDLERS = {
TruncToInt: "trunc_to_int",
sympy.floor: "floor_to_int",
sympy.ceiling: "ceil_to_int",
FloorToInt: "floor_to_int",
CeilToInt: "ceil_to_int",
RoundToInt: "round_to_int",
}
if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
return getattr(analysis, handler_name)(*args, index_dtype)
if hasattr(expr.func, "_torch_handler_name"):
handler_name = expr.func._torch_handler_name
else:

@ -1,12 +1,25 @@
import math
import operator
import sympy
import torch
from torch.utils._sympy.functions import (
_keep_float,
FloatPow,
FloatTrueDiv,
FloorDiv,
IntTrueDiv,
Mod,
OpaqueUnaryFn_exp,
OpaqueUnaryFn_log,
OpaqueUnaryFn_sqrt,
PowByNatural,
RoundDecimal,
RoundToInt,
ToFloat,
TruncToInt,
)
@ -62,20 +75,41 @@ class ReferenceAnalysis:
@staticmethod
def reciprocal(x):
return 1 / x
return FloatTrueDiv(1.0, x)
@staticmethod
def square(x):
return x * x
return PowByNatural(x, 2)
@staticmethod
def trunc_to_int(x, dtype):
return TruncToInt(x)
@staticmethod
def ceil_to_int(x, dtype):
return sympy.ceiling(x)
@staticmethod
def floor_to_int(x, dtype):
return sympy.floor(x)
@staticmethod
def floor(x):
return _keep_float(sympy.floor)(x)
@staticmethod
def ceil(x):
return _keep_float(sympy.ceiling)(x)
@staticmethod
def to_dtype(x, dtype):
if dtype == torch.float64:
return ToFloat(x)
raise NotImplementedError(f"to_dtype {dtype} NYI")
@staticmethod
def mod(x, y):
ret = abs(x) % abs(y)
# without check:
# tracing will fail trying to go through control-flow if x is Proxy()
if isinstance(x, (int, sympy.Number)) and x < 0:
ret *= -1
return ret
return Mod(x, y)
@staticmethod
def abs(x):
@ -87,37 +121,31 @@ class ReferenceAnalysis:
@staticmethod
def truediv(a, b):
return a / b
return FloatTrueDiv(a, b)
@staticmethod
def div(a, b):
return ReferenceAnalysis.truediv(a, b)
def int_truediv(a, b):
return IntTrueDiv(a, b)
@staticmethod
def floordiv(a, b):
if b == 0:
return sympy.nan if a == 0 else sympy.zoo
return a // b
return FloorDiv(a, b)
@staticmethod
def truncdiv(a, b):
result = a / b
if result.is_finite:
result = sympy.Integer(result)
return result
raise NotImplementedError("TODO: truncdiv")
@staticmethod
def add(a, b):
return a + b
return _keep_float(operator.add)(a, b)
@staticmethod
def mul(a, b):
return a * b
return _keep_float(operator.mul)(a, b)
@staticmethod
def sub(a, b):
return a - b
return _keep_float(operator.sub)(a, b)
@staticmethod
def exp(x):
@ -133,39 +161,27 @@ class ReferenceAnalysis:
@staticmethod
def pow(a, b):
return a**b
return _keep_float(FloatPow)(a, b)
@staticmethod
def pow_by_natural(a, b):
return PowByNatural(a, b)
@staticmethod
def minimum(a, b):
# Poorman's version of upcasting in Sympy
# This won't do for sympy.Expr as the casting does nothing for those
if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
result_type = sympy.Float
else:
assert a.is_Integer
assert b.is_Integer
result_type = sympy.Integer
return sympy.Min(result_type(a), result_type(b))
return sympy.Min(a, b)
@staticmethod
def maximum(a, b):
# Poorman's version of upcasting in Sympy
# This won't do for sympy.Expr as the casting does nothing for those
if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
result_type = sympy.Float
else:
assert a.is_Integer
assert b.is_Integer
result_type = sympy.Integer
return sympy.Max(result_type(a), result_type(b))
return sympy.Max(a, b)
@staticmethod
def floor(x):
return sympy.floor(x)
def round_to_int(a, dtype):
return RoundToInt(a)
@staticmethod
def ceil(x):
return sympy.ceiling(x)
def round_decimal(a, b):
return RoundDecimal(a, b)
# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain
@ -191,10 +207,20 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
def floordiv(a, b):
return a // b
@staticmethod
def mod(x, y):
return x % y
@staticmethod
def truncdiv(a, b):
return a / b
@staticmethod
def to_dtype(x, dtype):
if dtype == torch.float64:
return float(x)
raise NotImplementedError(f"to_dtype {dtype} NYI")
@staticmethod
def exp(x):
raise AssertionError("exp is not valid shape sympy expr")
@ -216,9 +242,40 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
return torch.sym_max(a, b)
@staticmethod
def floor(x):
def floor_to_int(x, dtype):
return math.floor(x)
@staticmethod
def ceil(x):
def ceil_to_int(x, dtype):
return math.ceil(x)
@staticmethod
def floor(x):
return float(math.floor(x))
@staticmethod
def ceil(x):
return float(math.ceil(x))
@staticmethod
def truediv(a, b):
return a / b
@staticmethod
def pow(a, b):
return a**b
@staticmethod
def pow_by_natural(a, b):
# Pray that safe_pow is not needed here lol. In particular, this
# never participates in VR low/high ranges, so overflow should be
# unlikely
return a**b
@staticmethod
def round_to_int(a, dtype):
return round(a)
@staticmethod
def round_decimal(a, b):
return round(a, ndigits=b)

@ -88,6 +88,7 @@ def try_solve(
# Return if we were able to isolate 'thing' on the left-hand side.
if isinstance(e, sympy.Rel) and e.lhs == thing:
log.debug("solved: %s ---> %s", expr, e)
return e, e.rhs
return None

@ -5,6 +5,7 @@ import itertools
import logging
import math
import operator
import sys
from typing import (
Callable,
Dict,
@ -25,17 +26,20 @@ import torch
from torch._prims_common import dtype_to_type
from .functions import (
OpaqueUnaryFn_acos,
OpaqueUnaryFn_asinh,
OpaqueUnaryFn_atan,
OpaqueUnaryFn_cosh,
_keep_float,
FloatTrueDiv,
FloorDiv,
IntTrueDiv,
OpaqueUnaryFn_exp,
OpaqueUnaryFn_log,
OpaqueUnaryFn_sinh,
OpaqueUnaryFn_sqrt,
OpaqueUnaryFn_tanh,
Round,
PowByNatural,
RoundDecimal,
RoundToInt,
safe_pow,
ToFloat,
TruncToFloat,
TruncToInt,
)
from .interp import sympy_interp
@ -120,6 +124,8 @@ class ValueRanges(Generic[_T]):
lower: _T
upper: _T
is_bool: bool
is_int: bool
is_float: bool
@overload
def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None:
@ -142,8 +148,39 @@ class ValueRanges(Generic[_T]):
# Because this is a frozen class
object.__setattr__(self, "lower", lower)
object.__setattr__(self, "upper", upper)
# Unlike bool/int in Python, we don't report bools as ints
object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean))
assert isinstance(upper, SympyBoolean) == self.is_bool
if self.is_bool:
assert isinstance(upper, SympyBoolean), (lower, upper)
# Warning: is_int/is_float is best effort. We do pretty well in
# Dynamo, but in Inductor these attributes are often wrong because we
# are not very rigorous in dtype analysis. This is also why we need
# the flexible analysis for is_int: sometimes a sympy.oo pops in for
# an integer bound. I would /like/ for us not to do this, but it's
# too hard to push the invariant through right now.
object.__setattr__(
self,
"is_int",
not self.is_bool
and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)),
)
"""
# This assert is just impossible right now, too many sympy bugs
if self.is_int:
# NB: sympy will sometimes randomly lose the float-ness of zero,
# so we also need to account for that in the assertion here.
# See also https://github.com/sympy/sympy/issues/26620
assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], (
lower,
upper,
)
assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper)
"""
# NB: [-oo, oo] always advertises as float!
object.__setattr__(self, "is_float", not self.is_bool and not self.is_int)
assert self.is_bool or self.is_int or self.is_float, (lower, upper)
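The classification rule in this constructor can be mimicked with ordinary Python values. The sketch below is an analogy only: `TypedRange` is a hypothetical class, Python `int`/`float`/`bool` stand in for the sympy types, and `math.inf` stands in for `sympy.oo`. It shows the "best effort" rule that a range counts as integer if either bound is an integer, and that a fully unbounded range advertises as float.

```python
import math
from dataclasses import dataclass
from typing import Union

Bound = Union[bool, int, float]


@dataclass(frozen=True)
class TypedRange:
    """Illustrative analogue of the is_bool / is_int / is_float flags."""

    lower: Bound
    upper: Bound

    @property
    def is_bool(self) -> bool:
        return isinstance(self.lower, bool)

    @property
    def is_int(self) -> bool:
        # bool is a subclass of int in Python, so exclude it explicitly.
        if self.is_bool:
            return False
        # Either bound being an int is enough; the other may be infinite.
        return isinstance(self.lower, int) or isinstance(self.upper, int)

    @property
    def is_float(self) -> bool:
        return not self.is_bool and not self.is_int


half_open_int = TypedRange(0, math.inf)          # integer lower bound
unbounded = TypedRange(-math.inf, math.inf)      # no integer bound at all
```

With this rule, `half_open_int` is treated as an integer range even though its upper bound is infinite, while `unbounded` falls through to float.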
def boolify(self) -> ValueRanges[SympyBoolean]:
if vr_is_bool(self):
@ -184,6 +221,8 @@ class ValueRanges(Generic[_T]):
if self == ValueRanges.unknown():
return other
assert self.is_bool == other.is_bool, (self, other)
assert self.is_int == other.is_int, (self, other)
assert self.is_float == other.is_float, (self, other)
if self.is_bool:
return ValueRanges(
sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)
@ -353,7 +392,12 @@ class SymPyValueRangeAnalysis:
# using nan makes subsequent computation throw, and for the purposes of optimization
# returning the range [-math.inf, math.inf] is equivalent to giving up
if isinstance(value, SupportsFloat) and math.isnan(value):
return ValueRanges.unknown()
if dtype == torch.bool:
return ValueRanges.unknown_bool()
elif dtype.is_floating_point:
return ValueRanges.unknown()
else:
return ValueRanges(-sys.maxsize - 1, sys.maxsize)
if is_python:
type_ = dtype_to_type(dtype)
@ -369,7 +413,18 @@ class SymPyValueRangeAnalysis:
# dtype is intXX
assert value.is_integer
return ValueRanges.wrap(value)
r = ValueRanges.wrap(value)
return r
@staticmethod
def to_dtype(a, dtype, src_dtype=None):
if dtype == torch.float64:
return ValueRanges.increasing_map(a, ToFloat)
return ValueRanges.unknown()
@staticmethod
def trunc_to_int(a, dtype):
return ValueRanges.increasing_map(a, TruncToInt)
@staticmethod
def not_(a):
@ -428,7 +483,9 @@ class SymPyValueRangeAnalysis:
@staticmethod
def add(a, b):
return ValueRanges.coordinatewise_increasing_map(a, b, operator.add)
return ValueRanges.coordinatewise_increasing_map(
a, b, _keep_float(operator.add)
)
@classmethod
def mul(cls, a, b):
@ -448,11 +505,20 @@ class SymPyValueRangeAnalysis:
else:
return a * b
return ValueRanges.coordinatewise_monotone_map(a, b, safe_mul)
return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul))
@classmethod
def div(cls, a, b):
return cls.truediv(a, b)
@staticmethod
def int_truediv(a, b):
a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b)
if 0 in b or (
(-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
):
return ValueRanges.unknown()
else:
return ValueRanges.coordinatewise_monotone_map(
a, b, _keep_float(IntTrueDiv)
)
@staticmethod
def truediv(a, b):
@ -463,18 +529,22 @@ class SymPyValueRangeAnalysis:
):
return ValueRanges.unknown()
else:
return ValueRanges.coordinatewise_monotone_map(a, b, operator.truediv)
return ValueRanges.coordinatewise_monotone_map(
a, b, _keep_float(FloatTrueDiv)
)
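The shape of the division range analysis above can be sketched with plain float intervals. This is an illustrative simplification, not the PR's code: `div_range` is a hypothetical helper, intervals are `(lower, upper)` tuples, and the result is taken as the min and max over the four corner quotients, which is the idea behind a coordinatewise monotone map.

```python
import math
from typing import Tuple

Interval = Tuple[float, float]
UNKNOWN: Interval = (-math.inf, math.inf)


def div_range(a: Interval, b: Interval) -> Interval:
    a_lo, a_hi = a
    b_lo, b_hi = b
    # Division is unbounded if the divisor range contains zero.
    if b_lo <= 0 <= b_hi:
        return UNKNOWN
    # inf / inf is undefined, so give up when both ranges are unbounded.
    if any(math.isinf(v) for v in a) and any(math.isinf(v) for v in b):
        return UNKNOWN
    # Division is monotone in each argument separately once zero is
    # excluded, so the extremes occur at the corners.
    corners = [x / y for x in (a_lo, a_hi) for y in (b_lo, b_hi)]
    return (min(corners), max(corners))
```

For instance, `div_range((1.0, 4.0), (2.0, 8.0))` evaluates the corners `0.5, 0.125, 2.0, 0.5` and returns their minimum and maximum.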
@staticmethod
def floordiv(a, b):
a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b)
if 0 in b or (
(-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
# TODO: make this more precise
(-sympy.oo in a or sympy.oo in a)
or (-sympy.oo in b or sympy.oo in b)
):
return ValueRanges.unknown()
else:
return ValueRanges.coordinatewise_monotone_map(a, b, operator.floordiv)
return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv)
@classmethod
def mod(cls, x, y):
@ -523,17 +593,51 @@ class SymPyValueRangeAnalysis:
@classmethod
def is_non_overlapping_and_dense_indicator(cls, *args):
return ValueRanges.unknown()
return ValueRanges.unknown() # TODO: type here is wrong
@classmethod
def pow_by_natural(cls, a, b):
a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b)
if a.is_singleton() and b.is_singleton():
return ValueRanges.wrap(safe_pow(a.lower, b.lower))
# NB: Exclude zero, because zero is special
elif a.lower >= 1:
# We should know that b >= 0 but we may have forgotten this fact due
# to replacements, so don't assert it, but DO clamp it to prevent
# degenerate problems
return ValueRanges.coordinatewise_increasing_map(
a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural
)
elif b.is_singleton():
if b.lower % 2 == 0:
# x^n where n is even
return ValueRanges.convex_min_zero_map(
a, lambda x: safe_pow(x, b.lower)
)
else:
# x^n where n is odd
return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower))
else:
# a is potentially negative, and we don't know if the exponent is
# even or odd. So just conservatively set the upper and lower
# bound based on what the maximum absolute value could be, in both
# directions
max_base = max(a.upper, -a.lower)
return ValueRanges(
-(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper)
)
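The case split on the parity of a singleton exponent can be checked with plain integers. In this sketch `pow_range` is a hypothetical helper (not from the PR) that assumes a fixed exponent `n >= 1` and omits the saturation that `safe_pow` performs; it only shows why even exponents need the "convex with minimum at zero" treatment while odd exponents are simply increasing.

```python
from typing import Tuple


def pow_range(lo: int, hi: int, n: int) -> Tuple[int, int]:
    """Bounds of x**n for integer x in [lo, hi], with fixed n >= 1."""
    if n % 2 == 1:
        # Odd exponent: x**n is increasing on the whole line.
        return (lo**n, hi**n)
    if lo <= 0 <= hi:
        # Even exponent and the range straddles zero: minimum is 0.
        return (0, max(lo**n, hi**n))
    # Even exponent on a range that is entirely one sign.
    a, b = lo**n, hi**n
    return (min(a, b), max(a, b))
```

A brute-force comparison over small ranges, as in `pow_range(-3, 2, 2) == (0, 9)`, is an easy way to sanity-check each branch.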
@classmethod
def pow(cls, a, b):
def is_integer(val):
return isinstance(val, int) or (
hasattr(val, "is_integer") and val.is_integer
)
return ValueRanges.unknown()
# We could implement all this, but for floating point pow, is there
# really a point?
"""
a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b)
# Not implemented yet. It's a bit tricky
# If you want to implement it, compute the partial derivatives of a ** b
# and check the ranges where the function is increasing / decreasing
@ -553,8 +657,7 @@ class SymPyValueRangeAnalysis:
if b == 0:
if not a.lower.is_finite:
return ValueRanges.unknown()
type_ = sympy.Float if a.lower.is_real else sympy.Integer
return ValueRanges.wrap(type_(1))
return ValueRanges.wrap(1.0)
if b < 0:
a = cls.reciprocal(a)
@ -563,21 +666,12 @@ class SymPyValueRangeAnalysis:
if a == ValueRanges.unknown():
return ValueRanges.unknown()
# Here b > 0
if not is_integer(b):
# If the base is positive, then we're good, otherwise nothing's defined
if a.lower >= 0:
return ValueRanges.increasing_map(a, lambda x: x**b)
else:
return ValueRanges.unknown()
# If the base is positive, then we're good, otherwise nothing's defined
if a.lower >= 0:
return ValueRanges.increasing_map(a, lambda x: x**b)
else:
# b > 0 integer
if b % 2 == 0:
# x^n where n is even
return ValueRanges.convex_min_zero_map(a, lambda x: x**b)
else:
# x^n where n is odd
return ValueRanges.increasing_map(a, lambda x: x**b)
return ValueRanges.unknown()
"""
@staticmethod
def reciprocal(x):
@ -586,7 +680,7 @@ class SymPyValueRangeAnalysis:
if 0 in x:
return ValueRanges.unknown()
else:
return ValueRanges.decreasing_map(x, lambda y: 1 / y)
return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y))
@staticmethod
def abs(x):
@ -615,45 +709,64 @@ class SymPyValueRangeAnalysis:
def min_or_max(a, b, fn):
a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b)
# Performs upcasting first
def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr:
# Poorman's version of upcasting in Sympy
# Inf is not a float...
if x.is_Integer and y.is_Integer:
result_type = sympy.Integer
elif x.is_rational and y.is_rational:
result_type = sympy.Rational
else:
assert x.is_real or not x.is_finite or y.is_real or not y.is_finite
result_type = sympy.Float
return fn(result_type(x), result_type(y))
return ValueRanges.coordinatewise_increasing_map(a, b, fn_)
return ValueRanges.coordinatewise_increasing_map(a, b, fn)
@classmethod
def floor(cls, x):
def floor_to_int(cls, x, dtype):
return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor)
@classmethod
def ceil(cls, x):
def ceil_to_int(cls, x, dtype):
return ValueRanges.increasing_map(
x, sympy.functions.elementary.integers.ceiling
)
# I think these implementations are sound. The hazard here is that sympy
# will carry out the floor/ceil at too high precision and then something
# bad will happen when we convert it to float.
#
# For truncation, the implementation is clearly sound, because the desired
# target float is always exactly representable, since you're just chopping
# off bits of the mantissa. But what about ceil/floor?
#
# The important constraint here is that we're not defining floor on
# arbitrary real numbers, only representable float numbers. So we can
# take advantage of the fact that before we reach the first
# unrepresentable integer in floating point space, we have the range of
# numbers corresponding to exponent zero: all integers, with no fractional
# amounts. floor/ceil is an identity operation in this case. In the
# range below here, representable floating point numbers are spaced
# exactly 1/2 apart, and notably, both the floor/ceil are defined floating
# point numbers. There is no "gap" as you step up to the next exponent.
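The representability argument in the comment above can be spot-checked with plain Python floats. This is a sketch under stated assumptions: IEEE 754 double precision, where the boundary is `2**52`, and Python 3.9+ for `math.ulp`. The helper name `floor_ceil_roundtrip_exact` and the sample values are illustrative and are not part of the PR.

```python
import math

# In double precision, every float with magnitude >= 2**52 is an integer.
BOUNDARY = 2.0 ** 52


def floor_ceil_roundtrip_exact(x: float) -> bool:
    """Check that floor(x) and ceil(x) survive conversion back to float."""
    lo, hi = math.floor(x), math.ceil(x)  # exact Python ints
    return int(float(lo)) == lo and int(float(hi)) == hi


# Spacing of representable floats just below and at the boundary.
spacing_below = math.ulp(BOUNDARY - 1.0)
spacing_at = math.ulp(BOUNDARY)

samples = [
    0.1,
    BOUNDARY - 1.5,
    BOUNDARY - 0.5,  # largest float with a fractional part
    BOUNDARY,
    BOUNDARY + 1.0,
    -(BOUNDARY - 0.5),
    1e300,  # far above the boundary; already an integer
]
all_exact = all(floor_ceil_roundtrip_exact(x) for x in samples)
```

A handful of samples is only a sanity check, not a proof. The comment's argument is what covers the general case.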
@classmethod
def round(cls, number, ndigits=None):
if ndigits is None:
fn = Round
else:
assert ndigits.is_singleton()
ndigits = ndigits.lower
# We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind
# the second parameter.
fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731
def floor(cls, x):
return ValueRanges.increasing_map(
x, _keep_float(sympy.functions.elementary.integers.floor)
)
@classmethod
def ceil(cls, x):
return ValueRanges.increasing_map(
x, _keep_float(sympy.functions.elementary.integers.ceiling)
)
@classmethod
def round_decimal(cls, number, ndigits):
if not ndigits.is_singleton():
return ValueRanges.unknown()
ndigits = ndigits.lower
# We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind
# the second parameter.
fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731
return ValueRanges.increasing_map(number, fn)
@classmethod
def round_to_int(cls, number, dtype):
return ValueRanges.increasing_map(number, RoundToInt)
# It's used in some models on symints
@staticmethod
def sqrt(x):
@@ -708,12 +821,15 @@ class SymPyValueRangeAnalysis:
@staticmethod
def cosh(x):
return ValueRanges(0.0, sympy.oo)
"""
x = ValueRanges.wrap(x)
if x.lower > 0:
return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh)
elif x.upper < 0:
return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh)
return ValueRanges(0.0, sympy.oo)
"""
@staticmethod
def sin(x):
@@ -723,7 +839,8 @@ class SymPyValueRangeAnalysis:
@staticmethod
def sinh(x):
return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh)
# return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh)
return ValueRanges(-sympy.oo, sympy.oo)
@staticmethod
def tan(x):
@@ -731,32 +848,37 @@ class SymPyValueRangeAnalysis:
@staticmethod
def tanh(x):
return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh)
# return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh)
return ValueRanges(-sympy.oo, sympy.oo)
@staticmethod
def asin(x):
return ValueRanges(-sympy.oo, sympy.oo)
"""
x = ValueRanges.wrap(x)
if -1 <= x.lower and x.upper <= 1:
return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh)
return ValueRanges.unknown()
"""
@staticmethod
def acos(x):
return ValueRanges(-sympy.oo, sympy.oo)
"""
x = ValueRanges.wrap(x)
if -1 <= x.lower and x.upper <= 1:
return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos)
return ValueRanges.unknown()
"""
@staticmethod
def atan(x):
return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan)
return ValueRanges(-sympy.oo, sympy.oo)
# return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan)
@staticmethod
def trunc(x):
def trunc(x):
return sympy.Integer(x) if x.is_finite else x
return ValueRanges.increasing_map(x, trunc)
return ValueRanges.increasing_map(x, TruncToFloat)
class ValueRangeAnalysis(SymPyValueRangeAnalysis):
@@ -791,9 +913,10 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
return ValueRanges.unknown()
def index_expr(self, index, dtype):
@classmethod
def index_expr(cls, index, dtype):
assert isinstance(index, ValueRanges)
return index
return cls.to_dtype(index, dtype)
@staticmethod
def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None):
@@ -830,12 +953,15 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
@staticmethod
def square(x):
return ValueRanges.convex_min_zero_map(x, lambda y: y * y)
return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))
@staticmethod
def neg(x):
return ValueRanges.decreasing_map(x, operator.neg)
# TODO: this is slightly inaccurate because truncdiv operates at integer
# precision, but we're going through float truediv which means we can
# potentially lose precision on the bounds
@classmethod
def truncdiv(cls, a, b):
x = cls.truediv(a, b)
@@ -856,6 +982,7 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
def bound_sympy(
expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None
) -> ValueRanges:
log.debug("bound_sympy(%s, %s)", expr, ranges)
if isinstance(expr, sympy.Number):
return ValueRanges.wrap(expr)