Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for the various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc. in sympy; instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and likewise for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the "divide out addition by gcd" optimization, because sympy gcd is over fields and is willing to generate rationals (and rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift and RShift now assert that they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating-point operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver.
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This lets us eventually generate accurate code for Python-semantics IntTrueDiv, which must be computed in a special way to preserve precision when the inputs are >= 2**53, beyond what first coercing the integers to floats and then doing true division gives (see the example below).
* Trunc is split into TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal is updated to consistently return a float.
* ToFloat is added for explicit coercion to float (required so we can enforce strict ValueRanges typing).

In **torch/__init__.py**, we modify SymInt and SymFloat to call into new bindings that route to these refined sympy operations. We also modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the result is always a float), making them inconsistent with builtins.min/max but making it possible to do type analysis without runtime information.

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but a dedicated handler is more convenient for roundtripping in Sympy.
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`.

These changes have consequences.
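To make the IntTrueDiv precision point concrete, here is a small plain-Python illustration (independent of this PR; standard builtins only) of why coercing to float before dividing loses precision once operands reach 2**53:

```python
a = 2**53 + 1    # 9007199254740993, not exactly representable as a 64-bit float
b = 3

# Python's int/int true division rounds the exact rational a/b just once:
print(a / b)                # 3002399751580331.0  (exact: a == 3 * 3002399751580331)

# Coercing to float first rounds twice (a is rounded down to 2**53 before dividing):
print(float(a) / float(b))  # 3002399751580330.5
```

This is the gap that a dedicated IntTrueDiv/`int_truediv` lowering is meant to preserve, rather than emitting a plain float division in codegen.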
First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2).
* Add support for the new Sympy functions in **torch/utils/_sympy/interp.py** and **torch/utils/_sympy/reference.py**.
  * In particular, in torch.utils._sympy.reference, we have a strong preference NOT to do nontrivial compute; instead, everything in the ops handler should map to a single sympy function.
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency here to fix the tests.
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed-precision equality are currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet.
* Update the ValueRanges logic to use the new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generating rational subexpressions by removing the simplification of `x // y` into `floor(x / y)`. That simplification triggers a further simplification rule, `(x + y) / c --> x / c + y / c`, which is bad because `x / c` is now a rational number.
* `_assert_bound_is_rational` is gone; we no longer generate rational bounds.
* Don't intersect non-int value ranges with the `int_range`.
* Support more sympy Functions for the guard SYMPY_INTERP.
* Assert that the type of a value range is consistent with the type of its variable.

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure that Wild/Symbol objects manually allocated in Inductor are marked `is_integer` so they are accepted when building expressions.
* **torch/_inductor/utils.py** - Make sure you actually pass sympy.Expr to these functions.
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - Don't use infinity to represent int ranges; use sys.maxsize - 1 instead.

Because we removed some simplifications that produced rationals, our symbolic reasoning has gotten worse in places and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
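As a usage note on the promotion semantics described above: after this PR, torch.sym_max/torch.sym_min promote to float whenever either argument is a float (matching the updated plain int/float fallback in torch/__init__.py later in this diff), whereas builtins.max preserves the type of whichever argument it returns:

```python
import builtins
import torch

print(builtins.max(0, 0.0))   # 0   -- on a tie, builtins.max returns the first argument (an int)
print(torch.sym_max(0, 0.0))  # 0.0 -- always a float when either argument is a float
```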
This commit is contained in:
parent c1a43a69e4
commit 2f7cfecd86
@ -49,15 +49,33 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
|||
virtual SymNode mul(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
// NB: legacy, prefer float_truediv or int_truediv
|
||||
virtual SymNode truediv(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
virtual SymNode float_truediv(const SymNode& other) {
|
||||
return truediv(other);
|
||||
}
|
||||
virtual SymNode int_truediv(const SymNode& other) {
|
||||
return truediv(other);
|
||||
}
|
||||
// NB: legacy, prefer float_pow or pow_by_natural
|
||||
virtual SymNode pow(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
virtual SymNode float_pow(const SymNode& other) {
|
||||
return pow(other);
|
||||
}
|
||||
virtual SymNode pow_by_natural(const SymNode& other) {
|
||||
return pow(other);
|
||||
}
|
||||
// NB: legacy, prefer int_floordiv
|
||||
virtual SymNode floordiv(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
virtual SymNode int_floordiv(const SymNode& other) {
|
||||
return floordiv(other);
|
||||
}
|
||||
virtual SymNode mod(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,13 +78,6 @@ for test in tests:
|
|||
del test
|
||||
|
||||
if TEST_Z3:
|
||||
# this only fails when z3 is available
|
||||
unittest.expectedFailure(
|
||||
# SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'.
|
||||
# Ref: https://github.com/sympy/sympy/issues/25146
|
||||
DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821
|
||||
)
|
||||
|
||||
if not config.inline_inbuilt_nn_modules:
|
||||
# TODO model is somehow not being freed when z3 is available
|
||||
unittest.expectedFailure(
|
||||
|
|
|
|||
|
|
@ -2385,8 +2385,7 @@ def forward(self, x):
|
|||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Constraints violated .*!(.*\n)*.*"
|
||||
"by dim0 = 2\\*dim1(.*\n)*.*"
|
||||
"Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*",
|
||||
"Not all values of dim0 .* satisfy the generated guard 4 <= .* and .* <= 10(.*\n)*.*",
|
||||
):
|
||||
torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes)
|
||||
|
||||
|
|
|
|||
|
|
@ -9309,7 +9309,7 @@ ShapeEnv not equal: field values don't match:
|
|||
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
|
||||
> Right: {0: 0, 1: 1}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
|
||||
> Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
|
||||
> Right: {}
|
||||
==> var_to_sources: values don't match.
|
||||
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
|
||||
|
|
@ -9343,7 +9343,7 @@ ShapeEnv not equal: field values don't match:
|
|||
> Left: 2
|
||||
> Right: 0
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), u1: ValueRanges(lower=0, upper=1, is_bool=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False)}
|
||||
> Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False, is_int=True, is_float=False), u1: ValueRanges(lower=0, upper=1, is_bool=False, is_int=True, is_float=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False, is_int=False, is_float=True)}
|
||||
> Right: {}
|
||||
""",
|
||||
)
|
||||
|
|
@ -9420,8 +9420,8 @@ ShapeEnv not equal: field values don't match:
|
|||
> Left: {s0: 3}
|
||||
> Right: {}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
|
||||
> Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
|
||||
> Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
|
||||
> Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
|
@ -9458,8 +9458,8 @@ ShapeEnv not equal: field values don't match:
|
|||
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
|
||||
> Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
|
||||
> Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
|
||||
> Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
|
@ -9484,10 +9484,7 @@ ShapeEnv not equal: field values don't match:
|
|||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> deferred_runtime_asserts: values don't match.
|
||||
> Left: {u0: [Eq(Mod(u0, 3), 0)]}
|
||||
> Right: {}
|
||||
==> divisible: values don't match.
|
||||
> Left: {Mod(u0, 3)}
|
||||
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
|
||||
> Right: {}
|
||||
==> name_to_node: values don't match.
|
||||
> Left: {_assert, eq, mod, u0}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,12 @@ from torch.testing._internal.common_utils import (
|
|||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
)
|
||||
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Round, RoundDecimal
|
||||
from torch.utils._sympy.functions import (
|
||||
FloorDiv,
|
||||
ModularIndexing,
|
||||
RoundDecimal,
|
||||
RoundToInt,
|
||||
)
|
||||
|
||||
|
||||
class TestIndexingSimplification(InductorTestCase):
|
||||
|
|
@ -168,21 +173,11 @@ class ExprPrinterTests(InductorTestCase):
|
|||
|
||||
common_cases = [
|
||||
# expr, result
|
||||
# Test exprs.
|
||||
(
|
||||
s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1),
|
||||
lambda c, L: f"((-1{L})*({c}/((-1{L}) + (2{L}*foo)))) + (foo*({c}/((-1{L}) + (2{L}*foo))))",
|
||||
),
|
||||
(s1 / (s2 - s3), lambda c, L: f"foo*({c}/(bar + ((-1{L})*baz)))"),
|
||||
# Test Pow directly.
|
||||
(
|
||||
sympy.Pow(s1 + s2, 0),
|
||||
lambda _, L: f"1{L}",
|
||||
), # note: simplified before _print_Pow
|
||||
(
|
||||
sympy.Pow(s1 + s2, -3),
|
||||
lambda c, _: f"{c}/((bar + foo)*(bar + foo)*(bar + foo))",
|
||||
),
|
||||
]
|
||||
|
||||
gpu_cases = common_cases + [
|
||||
|
|
@ -231,12 +226,10 @@ class ExprPrinterTests(InductorTestCase):
|
|||
self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""")
|
||||
|
||||
def test_print_round(self):
|
||||
expr = Round(sympy.Symbol("x", integer=True) / 2)
|
||||
expr = RoundToInt(sympy.Symbol("x", integer=True) / 2)
|
||||
self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""")
|
||||
self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""")
|
||||
self.assertExpectedInline(
|
||||
texpr(expr), """libdevice.llrint((1/2)*x).to(tl.int64)"""
|
||||
)
|
||||
self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""")
|
||||
|
||||
@parametrize("ndigits", [-1, 0, 1])
|
||||
def test_print_round_decimal(self, ndigits):
|
||||
|
|
@ -251,45 +244,18 @@ class ExprPrinterTests(InductorTestCase):
|
|||
f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}",
|
||||
)
|
||||
|
||||
expr = RoundDecimal(sympy.Symbol("x", integer=True), ndigits)
|
||||
if ndigits >= 0:
|
||||
for do_print in [pexpr, cexpr, texpr]:
|
||||
self.assertEqual(do_print(expr), "x")
|
||||
else:
|
||||
self.assertEqual(pexpr(expr), f"round(x, {ndigits})")
|
||||
for do_print in [cexpr, texpr]:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "only non-negative ndigits are currently supported"
|
||||
):
|
||||
do_print(expr)
|
||||
|
||||
def test_print_floor_div(self):
|
||||
for integer in [True, False]:
|
||||
s1 = sympy.Symbol("s1", integer=integer)
|
||||
s2 = sympy.Symbol("s2", integer=integer)
|
||||
expr = FloorDiv(s1, s2)
|
||||
self.assertEqual(pexpr(expr), "(s1 // s2)")
|
||||
if integer:
|
||||
self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)")
|
||||
else:
|
||||
self.assertEqual(
|
||||
cexpr(expr),
|
||||
"c10::div_floor_floating(static_cast<double>(s1), static_cast<double>(s2))",
|
||||
)
|
||||
s1 = sympy.Symbol("s1", integer=True)
|
||||
s2 = sympy.Symbol("s2", integer=True)
|
||||
expr = FloorDiv(s1, s2)
|
||||
self.assertEqual(pexpr(expr), "(s1 // s2)")
|
||||
self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)")
|
||||
|
||||
for integer in [True, False]:
|
||||
s1 = sympy.Symbol("s1", integer=integer)
|
||||
s2 = sympy.S(-1)
|
||||
expr = FloorDiv(s1, s2)
|
||||
if integer:
|
||||
self.assertEqual(pexpr(expr), "(-1)*s1")
|
||||
self.assertEqual(cexpr(expr), "(-1L)*s1")
|
||||
else:
|
||||
self.assertEqual(pexpr(expr), "(s1 // (-1))")
|
||||
self.assertEqual(
|
||||
cexpr(expr),
|
||||
"c10::div_floor_floating(static_cast<double>(s1), static_cast<double>((-1L)))",
|
||||
)
|
||||
s1 = sympy.Symbol("s1", integer=True)
|
||||
s2 = sympy.S(-1)
|
||||
expr = FloorDiv(s1, s2)
|
||||
self.assertEqual(pexpr(expr), "(-1)*s1")
|
||||
self.assertEqual(cexpr(expr), "(-1L)*s1")
|
||||
|
||||
def test_print_Min_Max(self):
|
||||
cases = (
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import contextlib
|
|||
import importlib
|
||||
|
||||
import math
|
||||
import operator
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
|
@ -649,6 +650,33 @@ class TestInductorDynamic(TestCase):
|
|||
actual = cfn(5)
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
def test_interpolate_ceil_eq(self, device):
|
||||
ceiling = math.ceil
|
||||
IntTrueDiv = operator.truediv
|
||||
|
||||
def fn(t):
|
||||
s0, s2, s3 = t.size()
|
||||
x = torch.zeros(
|
||||
(
|
||||
s0,
|
||||
2048,
|
||||
ceiling(IntTrueDiv(2 * ((s2 - 1) // 8) + 2, 1)),
|
||||
ceiling(IntTrueDiv(2 * ((s3 - 1) // 8) + 2, 1)),
|
||||
),
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
return torch.nn.functional.interpolate(
|
||||
x,
|
||||
scale_factor=2,
|
||||
mode="nearest",
|
||||
)
|
||||
|
||||
cfn = self.compile_fn(fn)
|
||||
arg = torch.randn(4, 16, 18)
|
||||
expect = fn(arg)
|
||||
actual = cfn(arg)
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
def test_full_recompiles(self, device):
|
||||
def fn(x):
|
||||
_, L = x.shape
|
||||
|
|
|
|||
|
|
@ -158,8 +158,12 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
torch.tensor([operator.sub(x.item(), y.item())]),
|
||||
torch.tensor([operator.mul(x.item(), y.item())]),
|
||||
torch.tensor([operator.truediv(x.item(), y.item())]),
|
||||
torch.tensor([operator.floordiv(x.item(), y.item())]),
|
||||
torch.tensor([operator.pow(x.item(), y.item())]),
|
||||
# This requires torch.sym_float, probably easy to lower to
|
||||
# ONNX but I don't know where to put it
|
||||
# torch.tensor([operator.floordiv(x.item(), y.item())]),
|
||||
# NB: abs so that the base and exponent are provably
|
||||
# non-negative, so we don't generate runtime asserts
|
||||
torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]),
|
||||
torch.tensor([operator.abs(x.item())]),
|
||||
torch.tensor([operator.neg(x.item())]),
|
||||
torch.tensor([math.ceil(x.item())]),
|
||||
|
|
|
|||
|
|
@ -205,15 +205,15 @@ def create_symtype(cls, pytype, shape_env, val, duck=True):
|
|||
|
||||
|
||||
# TODO: default duck to False
|
||||
def create_symint(shape_env, i: int, duck=True):
|
||||
def create_symint(shape_env, i: int, duck=True) -> SymInt:
|
||||
return create_symtype(SymInt, int, shape_env, i, duck=duck)
|
||||
|
||||
|
||||
def create_symbool(shape_env, b: bool):
|
||||
def create_symbool(shape_env, b: bool) -> SymBool:
|
||||
return create_symtype(SymBool, bool, shape_env, b)
|
||||
|
||||
|
||||
def create_symfloat(shape_env, f: float):
|
||||
def create_symfloat(shape_env, f: float) -> SymFloat:
|
||||
return create_symtype(SymFloat, float, shape_env, f)
|
||||
|
||||
|
||||
|
|
@ -457,14 +457,16 @@ class TestPySymInt(TestCase):
|
|||
r = sym_int(a1 / 2)
|
||||
self.assertEqual(guard_int(r), 3)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""")
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
|
||||
)
|
||||
|
||||
a3 = create_symint(shape_env, 3)
|
||||
r = sym_int(2.0 * torch.sym_float(a3))
|
||||
self.assertEqual(guard_int(r), 6)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)"""
|
||||
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
|
||||
)
|
||||
|
||||
def test_sym_sqrt(self):
|
||||
|
|
@ -474,7 +476,7 @@ class TestPySymInt(TestCase):
|
|||
self.assertEqual(r, 2)
|
||||
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)"""
|
||||
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)"""
|
||||
)
|
||||
|
||||
def test_sym_floor(self):
|
||||
|
|
@ -483,11 +485,17 @@ class TestPySymInt(TestCase):
|
|||
r = math.floor(a0 / 2)
|
||||
self.assertEqual(r, 2)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""")
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[0][0]),
|
||||
"""Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
|
||||
)
|
||||
r = math.floor(3.0 * a0)
|
||||
self.assertEqual(r, 15)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[1][0]),
|
||||
"""Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
|
||||
)
|
||||
|
||||
def test_sym_trunc(self):
|
||||
shape_env = ShapeEnv()
|
||||
|
|
@ -495,12 +503,14 @@ class TestPySymInt(TestCase):
|
|||
r = math.trunc(a0 / 2)
|
||||
self.assertEqual(r, 2)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""")
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
|
||||
)
|
||||
r = torch.sym_int(torch.sym_sqrt(a0))
|
||||
self.assertEqual(r, 2)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)"""
|
||||
str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)"""
|
||||
)
|
||||
|
||||
def test_sym_ceil(self):
|
||||
|
|
@ -510,12 +520,17 @@ class TestPySymInt(TestCase):
|
|||
self.assertEqual(r, 3)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)"""
|
||||
str(shape_env.guards[0][0]),
|
||||
"""Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
|
||||
)
|
||||
r = math.floor(3.0 * a0)
|
||||
r1 = 3.0 * a0
|
||||
r = math.floor(r1)
|
||||
self.assertEqual(r, 15)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[1][0]),
|
||||
"""Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
|
||||
)
|
||||
|
||||
def test_sym_ite(self):
|
||||
shape_env = ShapeEnv()
|
||||
|
|
@ -962,8 +977,14 @@ class f(torch.nn.Module):
|
|||
)
|
||||
class TestSymNumberMagicMethods(TestCase):
|
||||
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
|
||||
with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn):
|
||||
return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn)
|
||||
|
||||
def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
|
||||
# Helper function
|
||||
# NB: don't use one as that will get specialized
|
||||
# TODO: We don't have to circuitously create the float, can just
|
||||
# create a symfloat directly
|
||||
seed_node = (create_symint(shape_env, 2) / 2.0).node
|
||||
bool_seed_node = (create_symint(shape_env, 2) == 2).node
|
||||
|
||||
|
|
@ -976,27 +997,42 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
else:
|
||||
return torch.SymFloat(to_node(seed_node, inp))
|
||||
|
||||
if fn == "float_pow":
|
||||
if inp1 < 0:
|
||||
return
|
||||
|
||||
if fn == "pow_by_natural":
|
||||
if isinstance(inp1, float) or isinstance(inp2, float):
|
||||
return
|
||||
if inp2 < 0:
|
||||
return
|
||||
|
||||
def maybe_xfail(inp1, inp2):
|
||||
if fn == "sym_sqrt" and inp1 < 0:
|
||||
# ValueError: math domain error
|
||||
return self.assertRaises((ValueError,))
|
||||
elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
|
||||
elif (
|
||||
fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
|
||||
and inp2 == 0
|
||||
):
|
||||
# ZeroDivisionError: division by zero
|
||||
return self.assertRaises((ZeroDivisionError,))
|
||||
elif fn == "pow" and inp1 == 0 and inp2 < 0:
|
||||
elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0:
|
||||
# ZeroDivisionError: 0.0 cannot be raised to a negative power
|
||||
return self.assertRaises((ZeroDivisionError,))
|
||||
elif (
|
||||
fn == "pow"
|
||||
# TODO: dear catastrophe waitress,
|
||||
# this doesn't work
|
||||
fn in ["float_pow", "pow_by_natural"]
|
||||
and inp1 < 0
|
||||
and inp2 in (2.5, -2.5)
|
||||
and (
|
||||
type(inp1) in (SymFloat, SymInt) or type(inp2) in (SymFloat, SymInt)
|
||||
type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat)
|
||||
)
|
||||
and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float))
|
||||
):
|
||||
# Complex result, which we do not support:
|
||||
# TypeError: Cannot convert complex to float
|
||||
return self.assertRaises((TypeError,))
|
||||
return self.assertRaises((RuntimeError,))
|
||||
elif fn in ("lshift", "rshift") and not (
|
||||
isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
|
||||
):
|
||||
|
|
@ -1080,6 +1116,9 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
) and fn in sym_node.only_float_magic_methods:
|
||||
self.skipTest(f"{fn} is not an int method")
|
||||
|
||||
if second_type == "float" and fn in ["mod"]:
|
||||
self.skipTest(f"{fn} only handles int")
|
||||
|
||||
is_unary_fn = fn in sym_node.unary_methods or fn == "round"
|
||||
# Second argument is ignored for unary function. So only run for one type
|
||||
if is_unary_fn and second_type == "float":
|
||||
|
|
@ -1251,112 +1290,15 @@ class TestFloorDiv(TestCase):
|
|||
yield (-x, -y)
|
||||
|
||||
def test_floordiv_float_int(self):
|
||||
values = (
|
||||
(2.5, 2.1),
|
||||
(2.1, 2.5),
|
||||
(2.0, 2.1),
|
||||
(7, 2.5),
|
||||
(2.1, 7),
|
||||
(7, 2),
|
||||
)
|
||||
values = ((7, 2),)
|
||||
|
||||
for x, y in TestFloorDiv.yield_test_cases(values):
|
||||
self.assertEqual(
|
||||
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
|
||||
)
|
||||
|
||||
def test_floordiv_bool(self):
|
||||
values = (
|
||||
(False, True),
|
||||
(True, 2.5),
|
||||
(2.5, True),
|
||||
(False, 7),
|
||||
(7, True),
|
||||
)
|
||||
|
||||
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
|
||||
# Compares to int since our FloorDiv has no bool support
|
||||
self.assertEqual(
|
||||
TestFloorDiv.python_floordiv(x, y),
|
||||
TestFloorDiv.torch_floordiv(int(x), int(y)),
|
||||
)
|
||||
# Tests that our impl throws
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
(
|
||||
rf"unsupported operand type\(s\) for //: "
|
||||
rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
|
||||
rf", expected integer or real"
|
||||
),
|
||||
lambda: TestFloorDiv.torch_floordiv(x, y),
|
||||
)
|
||||
|
||||
def test_floordiv_complex(self):
|
||||
values = (
|
||||
(1.5 + 2.5j, 1.3 + 3.5j),
|
||||
(1.5 + 2.5j, 2.5),
|
||||
(2.5, 1.5 + 2.5j),
|
||||
(1.5 + 2.5j, 7),
|
||||
(7, 1.5 + 2.5j),
|
||||
)
|
||||
|
||||
for x, y in TestFloorDiv.yield_test_cases(values):
|
||||
# We don't test error messages to avoid depending on Python
|
||||
# interpreter version
|
||||
self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y))
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
(
|
||||
rf"unsupported operand type\(s\) for //: "
|
||||
rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
|
||||
rf", expected integer or real"
|
||||
),
|
||||
lambda: TestFloorDiv.torch_floordiv(x, y),
|
||||
)
|
||||
|
||||
def test_floordiv_div_by_zero(self):
|
||||
values = (
|
||||
(2.5, 0),
|
||||
(2.1, 0.0),
|
||||
(2.3, sympy.Symbol("s", zero=True)),
|
||||
)
|
||||
|
||||
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
|
||||
# We don't test error messages to avoid depending on Python
|
||||
# interpreter version
|
||||
if type(y) is not sympy.Symbol:
|
||||
self.assertRaises(
|
||||
ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y)
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ZeroDivisionError,
|
||||
"division by zero",
|
||||
lambda: TestFloorDiv.torch_floordiv(x, y),
|
||||
)
|
||||
|
||||
def test_floordiv_zero_base(self):
|
||||
values = (
|
||||
(0, 2.5),
|
||||
(0.0, 2.1),
|
||||
(sympy.Symbol("s", zero=True), 2.3),
|
||||
)
|
||||
|
||||
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
|
||||
if type(x) is not sympy.Symbol:
|
||||
self.assertEqual(
|
||||
TestFloorDiv.python_floordiv(x, y),
|
||||
TestFloorDiv.torch_floordiv(x, y),
|
||||
)
|
||||
else:
|
||||
self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y))
|
||||
|
||||
def test_floordiv_div_by_one(self):
|
||||
values = (
|
||||
(2.5, 1),
|
||||
(2.1, 1.0),
|
||||
(2, 1.0),
|
||||
(2, 1),
|
||||
)
|
||||
values = ((2, 1),)
|
||||
|
||||
for x, y in TestFloorDiv.yield_test_cases(values):
|
||||
self.assertEqual(
|
||||
|
|
@ -1367,12 +1309,7 @@ class TestFloorDiv(TestCase):
|
|||
# Tests how we simplify or evaluate FloorDiv without free variables
|
||||
shape_env = ShapeEnv()
|
||||
result = 21
|
||||
exprs = (
|
||||
7 * FloorDiv(6, 2),
|
||||
7 * FloorDiv(6.28, 2),
|
||||
7 * FloorDiv(6.28, 2.0),
|
||||
7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))),
|
||||
)
|
||||
exprs = (7 * FloorDiv(6, 2),)
|
||||
|
||||
for expr in exprs:
|
||||
self.assertEqual(expr, result)
|
||||
|
|
@ -1382,33 +1319,10 @@ class TestFloorDiv(TestCase):
|
|||
self.assertEqual(shape_env.simplify(expr), result)
|
||||
self.assertEqual(shape_env.evaluate_expr(expr), result)
|
||||
|
||||
def test_floordiv_simplify_rational(self):
|
||||
result = 21
|
||||
|
||||
a = sympy.Symbol("a", integer=True)
|
||||
b = sympy.Symbol("b")
|
||||
|
||||
cases = [
|
||||
(FloorDiv(a, sympy.Rational(1, 8)), 8 * a),
|
||||
(FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)),
|
||||
]
|
||||
|
||||
for expr, expected in cases:
|
||||
self.assertEqual(expr, expected)
|
||||
|
||||
def test_floordiv_assumptions(self):
|
||||
# We define two Symbols (with different names) for each type to make
|
||||
# sure the behavior is consistent regardless of whether both arguments
|
||||
# are the same object or not.
|
||||
cases = (
|
||||
sympy.Symbol("i1", integer=True),
|
||||
sympy.Symbol("i2", integer=True),
|
||||
sympy.Symbol("r1", real=True),
|
||||
sympy.Symbol("r2", real=True),
|
||||
sympy.Symbol("c1", complex=True, real=False, integer=False),
|
||||
sympy.Symbol("c2", complex=True, real=False, integer=False),
|
||||
sympy.Symbol("s1"),
|
||||
sympy.Symbol("s2"),
|
||||
)
|
||||
|
||||
for base, divisor in itertools.product(cases, repeat=2):
|
||||
|
|
|
|||
|
|
@ -1618,7 +1618,8 @@ def forward(self, lengths_1, values_1):
|
|||
self.assertExpectedInline(r, """\
|
||||
def forward(self, a_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
|
||||
pow_1 = sym_size_int ** 0.5; sym_size_int = None
|
||||
sym_float = torch.sym_float(sym_size_int); sym_size_int = None
|
||||
pow_1 = sym_float ** 0.5; sym_float = None
|
||||
div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
|
||||
return div""")
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,12 @@ UNARY_OPS = [
|
|||
"floor",
|
||||
"ceil",
|
||||
]
|
||||
BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"]
|
||||
BINARY_OPS = [
|
||||
"truediv", "floordiv",
|
||||
# "truncdiv", # TODO
|
||||
# NB: pow is float_pow
|
||||
"add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod"
|
||||
]
|
||||
|
||||
UNARY_BOOL_OPS = ["not_"]
|
||||
BINARY_BOOL_OPS = ["or_", "and_"]
|
||||
|
|
@ -81,16 +86,24 @@ def valid_unary(fn, v):
|
|||
|
||||
def valid_binary(fn, a, b):
|
||||
if fn == "pow" and (
|
||||
# sympy will expand to x*x*... for integral b; don't do it if it's big
|
||||
b > 4
|
||||
or ( # sympy will expand to x*x*... for integral b; don't do it if it's big
|
||||
a <= 0 and b == -1
|
||||
)
|
||||
or (a == b == 0) # no imaginary numbers # 0**0 is undefined
|
||||
# no imaginary numbers
|
||||
or a <= 0
|
||||
# 0**0 is undefined
|
||||
or (a == b == 0)
|
||||
):
|
||||
return False
|
||||
elif fn == "mod" and b == 0:
|
||||
elif fn == "pow_by_natural" and (
|
||||
# sympy will expand to x*x*... for integral b; don't do it if it's big
|
||||
b > 4
|
||||
or b < 0
|
||||
or (a == b == 0)
|
||||
):
|
||||
return False
|
||||
elif (fn == "div" or fn == "truediv") and b == 0:
|
||||
elif fn == "mod" and (a < 0 or b <= 0):
|
||||
return False
|
||||
elif (fn in ["div", "truediv", "floordiv"]) and b == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
|
@ -130,27 +143,26 @@ class TestValueRanges(TestCase):
|
|||
ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))
|
||||
|
||||
@parametrize("fn", BINARY_OPS)
|
||||
@parametrize("dtype_a", ("int", "float"))
|
||||
@parametrize("dtype_b", ("int", "float"))
|
||||
def test_binary_ref(self, fn, dtype_a, dtype_b):
|
||||
@parametrize("dtype", ("int", "float"))
|
||||
def test_binary_ref(self, fn, dtype):
|
||||
to_dtype = {"int": sympy.Integer, "float": sympy.Float}
|
||||
dtype_a = to_dtype[dtype_a]
|
||||
dtype_b = to_dtype[dtype_b]
|
||||
# Don't test float on int only methods
|
||||
if dtype == "float" and fn in ["pow_by_natural", "mod"]:
|
||||
return
|
||||
dtype = to_dtype[dtype]
|
||||
for a, b in itertools.product(CONSTANTS, repeat=2):
|
||||
if not valid_binary(fn, a, b):
|
||||
continue
|
||||
a = dtype_a(a)
|
||||
b = dtype_b(b)
|
||||
a = dtype(a)
|
||||
b = dtype(b)
|
||||
with self.subTest(a=a, b=b):
|
||||
r = getattr(ValueRangeAnalysis, fn)(a, b)
|
||||
if r == ValueRanges.unknown():
|
||||
continue
|
||||
ref_r = getattr(ReferenceAnalysis, fn)(a, b)
|
||||
|
||||
# sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf
|
||||
if fn != "floordiv":
|
||||
self.assertEqual(r.lower.is_integer, r.upper.is_integer)
|
||||
self.assertEqual(ref_r.is_integer, r.upper.is_integer)
|
||||
self.assertEqual(r.lower.is_integer, r.upper.is_integer)
|
||||
self.assertEqual(ref_r.is_integer, r.upper.is_integer)
|
||||
self.assertEqual(r.lower, r.upper)
|
||||
self.assertEqual(ref_r, r.lower)
|
||||
|
||||
|
|
@ -200,7 +212,8 @@ class TestValueRanges(TestCase):
|
|||
|
||||
@parametrize("fn", UNARY_OPS)
|
||||
def test_unary_ref_range(self, fn):
|
||||
vals = [-sympy.oo, *CONSTANTS, sympy.oo]
|
||||
# TODO: bring back sympy.oo testing for float unary fns
|
||||
vals = CONSTANTS
|
||||
for a in generate_range(vals):
|
||||
with self.subTest(a=a):
|
||||
ref_r = getattr(ValueRangeAnalysis, fn)(a)
|
||||
|
|
@ -216,40 +229,26 @@ class TestValueRanges(TestCase):
|
|||
# This takes about 4s for all the variants
|
||||
@parametrize("fn", BINARY_OPS + COMPARE_OPS)
|
||||
def test_binary_ref_range(self, fn):
|
||||
vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo]
|
||||
# TODO: bring back sympy.oo testing for float unary fns
|
||||
vals = LESS_CONSTANTS
|
||||
for a, b in itertools.product(generate_range(vals), repeat=2):
|
||||
# don't attempt pow on exponents that are too large (but oo is OK)
|
||||
if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
|
||||
continue
|
||||
with self.subTest(a=a, b=b):
|
||||
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
|
||||
for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
|
||||
if a0 not in a or b0 not in b:
|
||||
continue
|
||||
if not valid_binary(fn, a0, b0):
|
||||
continue
|
||||
with self.subTest(a0=a0, b0=b0):
|
||||
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
|
||||
r = getattr(ReferenceAnalysis, fn)(
|
||||
sympy.Integer(a0), sympy.Integer(b0)
|
||||
)
|
||||
if r.is_finite:
|
||||
self.assertIn(r, ref_r)
|
||||
|
||||
def test_rational_bounds(self):
|
||||
# Repro from https://github.com/pytorch/pytorch/issues/105097
|
||||
from sympy import floor, Eq
|
||||
shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
|
||||
new_expr = (
|
||||
Eq(30 * floor(4 * ((shape_0 + 1) // 96) *
|
||||
((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 +
|
||||
2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647),
|
||||
2880 * floor(((shape_0 + 1) // 96) *
|
||||
((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 +
|
||||
323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764)))
|
||||
new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
|
||||
self.assertTrue(new_expr.subs({shape_0: 95}))
|
||||
self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
|
||||
|
||||
|
||||
class TestSympyInterp(TestCase):
|
||||
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
|
||||
|
|
@ -258,7 +257,13 @@ class TestSympyInterp(TestCase):
|
|||
if fn in ("div", "truncdiv", "minimum", "maximum", "mod"):
|
||||
return
|
||||
|
||||
from sympy.abc import x, y
|
||||
is_integer = None
|
||||
if fn == "pow_by_natural":
|
||||
is_integer = True
|
||||
|
||||
x = sympy.Dummy('x', integer=is_integer)
|
||||
y = sympy.Dummy('y', integer=is_integer)
|
||||
|
||||
vals = CONSTANTS
|
||||
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
|
||||
vals = [True, False]
|
||||
|
|
@ -300,29 +305,17 @@ class TestSympyInterp(TestCase):
|
|||
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
|
||||
arity = 2
|
||||
|
||||
from sympy.abc import x, y
|
||||
is_integer = None
|
||||
if fn == "pow_by_natural":
|
||||
is_integer = True
|
||||
|
||||
x = sympy.Dummy('x', integer=is_integer)
|
||||
y = sympy.Dummy('y', integer=is_integer)
|
||||
|
||||
symbols = [x]
|
||||
if arity == 2:
|
||||
symbols = [x, y]
|
||||
|
||||
# Workaround mpf from symbol error
|
||||
if fn == "minimum":
|
||||
sympy_expr = sympy.Min(x, y)
|
||||
elif fn == "maximum":
|
||||
sympy_expr = sympy.Max(x, y)
|
||||
else:
|
||||
sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
|
||||
|
||||
if arity == 1:
|
||||
def trace_f(px):
|
||||
return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
|
||||
else:
|
||||
def trace_f(px, py):
|
||||
return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)
|
||||
|
||||
gm = fx.symbolic_trace(trace_f)
|
||||
|
||||
for args in itertools.product(vals, repeat=arity):
|
||||
if arity == 1 and not valid_unary(fn, *args):
|
||||
continue
|
||||
|
|
@ -330,11 +323,28 @@ class TestSympyInterp(TestCase):
|
|||
continue
|
||||
if fn == "truncdiv" and args[1] == 0:
|
||||
continue
|
||||
elif fn == "pow" and (args[0] == 0 and args[1] <= 0):
|
||||
elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0):
|
||||
continue
|
||||
elif fn == "floordiv" and args[1] == 0:
|
||||
continue
|
||||
with self.subTest(args=args):
|
||||
# Workaround mpf from symbol error
|
||||
if fn == "minimum":
|
||||
sympy_expr = sympy.Min(x, y)
|
||||
elif fn == "maximum":
|
||||
sympy_expr = sympy.Max(x, y)
|
||||
else:
|
||||
sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
|
||||
|
||||
if arity == 1:
|
||||
def trace_f(px):
|
||||
return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
|
||||
else:
|
||||
def trace_f(px, py):
|
||||
return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)
|
||||
|
||||
gm = fx.symbolic_trace(trace_f)
|
||||
|
||||
self.assertEqual(
|
||||
sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr),
|
||||
gm(*args)
|
||||
|
|
|
|||
|
|
@ -316,6 +316,75 @@ class SymInt:
|
|||
|
||||
# Magic methods installed by torch.fx.experimental.sym_node
|
||||
|
||||
def __round__(self, ndigits=None):
|
||||
return self
|
||||
|
||||
def __truediv__(self, other):
|
||||
if isinstance(other, (builtins.float, SymFloat)):
|
||||
return sym_float(self).__float_truediv__(other)
|
||||
if not isinstance(other, (builtins.int, SymInt)):
|
||||
return NotImplemented
|
||||
return self.__int_truediv__(other)
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
if isinstance(other, (builtins.float, SymFloat)):
|
||||
return sym_float(self).__rfloat_truediv__(other)
|
||||
if not isinstance(other, (builtins.int, SymInt)):
|
||||
return NotImplemented
|
||||
return self.__rint_truediv__(other)
|
||||
|
||||
def __floordiv__(self, other):
|
||||
if isinstance(other, (builtins.float, SymFloat)):
|
||||
return torch.sym_float(math.floor(sym_float(self) / other))
|
||||
if not isinstance(other, (builtins.int, SymInt)):
|
||||
return NotImplemented
|
||||
return self.__int_floordiv__(other)
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
if isinstance(other, (builtins.float, SymFloat)):
|
||||
return torch.sym_float(math.floor(other / sym_float(self)))
|
||||
if not isinstance(other, (builtins.int, SymInt)):
|
||||
return NotImplemented
|
||||
return self.__rint_floordiv__(other)
|
||||
|
||||
# nb: complex is impossible to handle correctly lol, with
|
||||
# negative base and integral float need to diverge semantics and
|
||||
# just always return complex. Neener neener pretend this problem
|
||||
# doesn't exist
|
||||
def __pow__(self, other):
|
||||
if isinstance(other, (builtins.float, SymFloat)):
|
||||
return sym_float(self).__pow__(other)
|
||||
if not isinstance(other, (builtins.int, SymInt)):
|
||||
return NotImplemented
|
||||
# Guards! This guard is necessary because we need to know it to
|
||||
# determine the output type of this operation
|
||||
if other >= 0:
|
||||
return self.__pow_by_natural__(other)
|
||||
else:
|
||||
# Mercifully, when the exponent is negative, Python just promotes
|
||||
# to doubles and does a float pow:
|
||||
#
|
||||
# if (Py_SIZE(b) < 0 && c == NULL) {
|
||||
# /* if exponent is negative and there's no modulus:
|
||||
# return a float. This works because we know
|
||||
# that this calls float_pow() which converts its
|
||||
# arguments to double. */
|
||||
# Py_DECREF(a);
|
||||
# Py_DECREF(b);
|
||||
# return PyFloat_Type.tp_as_number->nb_power(v, w, x);
|
||||
# }
|
||||
return sym_float(self).__pow__(sym_float(other))
|
||||
|
||||
def __rpow__(self, other):
|
||||
if isinstance(other, (builtins.float, SymFloat)):
|
||||
return sym_float(self).__rpow__(other)
|
||||
if not isinstance(other, (builtins.int, SymInt)):
|
||||
return NotImplemented
|
||||
if self >= 0: # self is exponent
|
||||
return self.__rpow_by_natural__(other)
|
||||
else:
|
||||
return sym_float(self).__rpow__(sym_float(other))
|
||||
|
||||
def __eq__(self, other: object) -> builtins.bool:
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
|
|
@ -337,6 +406,24 @@ class SymInt:
|
|||
def __mul__(self, other) -> "SymInt":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __pow_by_natural__(self, other) -> "SymInt":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __rpow_by_natural__(self, other) -> "SymInt":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __int_truediv__(self, other) -> "SymFloat":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __rint_truediv__(self, other) -> "SymFloat":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __int_floordiv__(self, other) -> "SymFloat":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __rint_floordiv__(self, other) -> "SymFloat":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __sym_max__(self, other):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
|
|
@ -371,9 +458,43 @@ class SymFloat:
|
|||
# class has a field named node that stores SymNode
|
||||
self.node = node
|
||||
|
||||
def __truediv__(self, other):
|
||||
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
|
||||
return NotImplemented
|
||||
return self.__float_truediv__(sym_float(other))
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
|
||||
return NotImplemented
|
||||
return self.__rfloat_truediv__(sym_float(other))
|
||||
|
||||
def __floordiv__(self, other):
|
||||
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
|
||||
return NotImplemented
|
||||
return torch.sym_float(math.floor(self / sym_float(other)))
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
|
||||
return NotImplemented
|
||||
return torch.sym_float(math.floor(sym_float(other) / self))
|
||||
|
||||
def __bool__(self):
|
||||
return self.node.bool_()
|
||||
|
||||
# Symbolic power does NOT work with negative base, this is to avoid
|
||||
# potential complex outputs
|
||||
def __pow__(self, other):
|
||||
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
|
||||
return NotImplemented
|
||||
torch._check(self >= 0)
|
||||
return self.__float_pow__(other)
|
||||
|
||||
def __rpow__(self, other):
|
||||
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
|
||||
return NotImplemented
|
||||
torch._check(other >= 0)
|
||||
return self.__rfloat_pow__(other)
|
||||
|
||||
# Magic methods installed by torch.fx.experimental.sym_node
|
||||
|
||||
def __eq__(self, other: object) -> builtins.bool:
|
||||
|
|
@ -391,6 +512,18 @@ class SymFloat:
|
|||
def __ge__(self, other) -> builtins.bool:
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __float_pow__(self, other) -> "SymFloat":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __rfloat_pow__(self, other) -> "SymFloat":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __float_truediv__(self, other) -> "SymFloat":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __rfloat_truediv__(self, other) -> "SymFloat":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __trunc__(self):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
|
|
@ -524,7 +657,12 @@ def sym_int(a):
|
|||
return py_int(a) # type: ignore[operator]
|
||||
|
||||
def sym_max(a, b):
|
||||
""" SymInt-aware utility for max()."""
|
||||
"""
|
||||
SymInt-aware utility for max which avoids branching on a < b.
|
||||
Unlike builtins.max(), this only works for int/float, and it always
|
||||
promotes to float if any argument is float (unlike builtins.max, which
|
||||
will faithfully preserve the type of the input argument).
|
||||
"""
|
||||
from .overrides import has_torch_function, handle_torch_function
|
||||
|
||||
if has_torch_function((a, b)):
|
||||
|
|
@ -532,14 +670,19 @@ def sym_max(a, b):
|
|||
if isinstance(a, (SymInt, SymFloat)):
|
||||
return a.__sym_max__(b)
|
||||
elif isinstance(b, (SymInt, SymFloat)):
|
||||
# NB: If you actually care about preserving output type exactly
|
||||
# if you do something like max(0, 0.0), it is NOT sound to treat
|
||||
# min/max as commutative
|
||||
# Due to promotion semantics, this is operator is commutative:
|
||||
# max(1, 1.0) === max(1.0, 1) === 1.0
|
||||
return b.__sym_max__(a)
|
||||
return builtins.max(a, b) # type: ignore[operator]
|
||||
# TODO: Probably can make bool work too, just lazy
|
||||
assert isinstance(a, (builtins.int, builtins.float)), type(a)
|
||||
assert isinstance(b, (builtins.int, builtins.float)), type(b)
|
||||
if isinstance(a, builtins.float) or isinstance(b, builtins.float):
|
||||
return builtins.float(builtins.max(a, b))
|
||||
else:
|
||||
return builtins.max(a, b)
|
||||
|
||||
def sym_min(a, b):
|
||||
""" SymInt-aware utility for max()."""
|
||||
""" SymInt-aware utility for min()."""
|
||||
from .overrides import has_torch_function, handle_torch_function
|
||||
|
||||
if has_torch_function((a, b)):
|
||||
|
|
@ -548,7 +691,12 @@ def sym_min(a, b):
|
|||
return a.__sym_min__(b)
|
||||
elif isinstance(b, (SymInt, SymFloat)):
|
||||
return b.__sym_min__(a)
|
||||
return builtins.min(a, b) # type: ignore[operator]
|
||||
assert isinstance(a, (builtins.int, builtins.float)), type(a)
|
||||
assert isinstance(b, (builtins.int, builtins.float)), type(b)
|
||||
if isinstance(a, builtins.float) or isinstance(b, builtins.float):
|
||||
return builtins.float(builtins.min(a, b))
|
||||
else:
|
||||
return builtins.min(a, b)
|
||||
|
||||
# Drop in replacement for math.sqrt, math.sin, math.cos etc
|
||||
current_module = sys.modules[__name__]
|
||||
|
|
|
|||
|
|
@ -1474,10 +1474,15 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
# Here we force symbols corresponding to SymInts to be at least integers.
|
||||
# Otherwise some expressions that the shape env would otherwise evaluate to False,
|
||||
# e.g., 2*s = 9, can have rational solutions, e.g., 9/2.
|
||||
# TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024)
|
||||
sym = sym.subs(
|
||||
{s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols}
|
||||
)
|
||||
if isinstance(sym, sympy.Symbol):
|
||||
# We need to check if the symbol has already been allocated,
|
||||
# self.symbol_name_to_symbol is not enough because the
|
||||
# integer-ification of symbols can induce simplification;
|
||||
# e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral
|
||||
if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val:
|
||||
self.symbol_name_to_symbol[val.expr_str] = sym
|
||||
if hint is not None:
|
||||
self.shape_env.add_var_to_val(sym, hint)
|
||||
|
|
@ -1496,7 +1501,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
free_symbols = sym.free_symbols
|
||||
for s in free_symbols:
|
||||
if s.name not in self.symbol_name_to_symbol:
|
||||
self.symbol_name_to_symbol[s.name] = s
|
||||
self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment]
|
||||
if vr := self.symbol_name_to_range.get(s.name):
|
||||
self.shape_env.constrain_symbol_range(
|
||||
s,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import operator
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict
|
||||
|
|
@ -11,6 +12,9 @@ from .utils import cache_on_self, dominated_nodes
|
|||
from .virtualized import V
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BoundVars:
|
||||
"""
|
||||
Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run()
|
||||
|
|
@ -55,6 +59,7 @@ class BoundVars:
|
|||
|
||||
with V.set_ops_handler(ValueRangeAnalysis()):
|
||||
interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
|
||||
log.debug("get_bounds:\n%s", self.loop_body.root_block.graph)
|
||||
interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
|
||||
return self._bounds
|
||||
|
||||
|
|
|
|||
|
|
@ -340,6 +340,8 @@ class DataTypePropagation:
|
|||
DataTypePropagation.propagate_loopbody(node._body)
|
||||
|
||||
|
||||
# This printer contains rules that are supposed to be generic for both C/C++ and
|
||||
# Python
|
||||
class ExprPrinter(Printer):
|
||||
@staticmethod
|
||||
def paren(string):
|
||||
|
|
@ -369,12 +371,6 @@ class ExprPrinter(Printer):
|
|||
return string
|
||||
return f"({string})"
|
||||
|
||||
def _print_Infinity(self, expr):
|
||||
return "math.inf"
|
||||
|
||||
def _print_NegativeInfinity(self, expr):
|
||||
return "-math.inf"
|
||||
|
||||
def _print_Relational(self, expr):
|
||||
return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
|
|
@ -384,11 +380,14 @@ class ExprPrinter(Printer):
|
|||
def _print_Add(self, expr):
|
||||
return " + ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
# NB: this is OK to put here, because Mod is only defined for positive
|
||||
# numbers, and so across C/Python its behavior is consistent
|
||||
def _print_Mod(self, expr):
|
||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
def _print_FloorDiv(self, expr):
|
||||
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
|
||||
def _print_FloatTrueDiv(self, expr):
|
||||
lhs, rhs = expr.args
|
||||
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
|
||||
|
||||
def _print_CleanDiv(self, expr):
|
||||
return self._print_FloorDiv(expr)
|
||||
|
|
@ -399,10 +398,84 @@ class ExprPrinter(Printer):
|
|||
# Go figure...
|
||||
return " >= ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
# NB: The C implementation is injected into codegen at
|
||||
# torch/_inductor/codegen/wrapper.py
|
||||
def _print_align(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"align({self._print(expr.args[0])})"
|
||||
|
||||
# This must be implemented because sympy will collect x * x into Pow(x, 2), without
|
||||
# any explicit intervention. We print it just like x * x, notably, we
|
||||
# never generate sympy.Pow with floats.
|
||||
#
|
||||
# NB: this pow by natural, you should never have used builtin sympy.pow
|
||||
# for FloatPow, and a symbolic exponent should be PowByNatural. These
|
||||
# means exp is guaranteed to be integer.
|
||||
def _print_Pow(self, expr):
|
||||
base, exp = expr.args
|
||||
base = self._print(base)
|
||||
assert exp == int(exp), exp
|
||||
exp = int(exp)
|
||||
assert exp >= 0
|
||||
if exp > 0:
|
||||
return "*".join([self.paren(base)] * exp)
|
||||
else: # exp == 0
|
||||
return "1"
|
||||
|
||||
# Explicit NotImplemented functions are to prevent default sympy printing
|
||||
# behavior, which will just barf out ToFloat(...) to your IR. The error
|
||||
# message is better here because it tells you which printer class it needs
|
||||
# to go in.
|
||||
|
||||
def _print_ToFloat(self, expr):
|
||||
raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
|
||||
|
||||
def _print_Infinity(self, expr):
|
||||
raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
|
||||
|
||||
def _print_NegativeInfinity(self, expr):
|
||||
raise NotImplementedError(
|
||||
f"_print_NegativeInfinity not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def _print_FloorDiv(self, expr):
|
||||
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
|
||||
|
||||
def _print_PythonMod(self, expr):
|
||||
raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
|
||||
|
||||
def _print_IntTrueDiv(self, expr):
|
||||
raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
|
||||
|
||||
def _print_PowByNatural(self, expr):
|
||||
raise NotImplementedError(
|
||||
f"_print_PowByNatural not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def _print_FloatPow(self, expr):
|
||||
raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
|
||||
|
||||
def _print_TruncToInt(self, expr):
|
||||
raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
|
||||
|
||||
def _print_RoundToInt(self, expr):
|
||||
raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
|
||||
|
||||
def _print_RoundDecimal(self, expr):
|
||||
raise NotImplementedError(
|
||||
f"_print_RoundDecimal not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
# NB: Some float operations are INTENTIONALLY not implemented for
|
||||
# printers. You can implement them as a quick unblock, but it is better
|
||||
# to ask yourself why we haven't done this computation in the Tensor
|
||||
# universe instead
|
||||
|
||||
def _print_TruncToFloat(self, expr):
|
||||
raise NotImplementedError(
|
||||
f"_print_TruncToFloat not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def doprint(self, expr, *, simplify: bool = True):
|
||||
# TODO: why are people passing strings to the printer here :think:
|
||||
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
|
||||
|
|
@@ -411,6 +484,10 @@ class ExprPrinter(Printer):

class PythonPrinter(ExprPrinter):
    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"float({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))

@@ -420,56 +497,72 @@ class PythonPrinter(ExprPrinter):
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_Pow(self, expr):
        # Pow() confuses triton
    def _print_FloatPow(self, expr):
        base, exp = expr.args
        # NB: Remember this is sizevar computation! You don't typically
        # expect to have to do floating point computation including exponents
        # in sizevar compute. Instead of adding support for floating
        # point pow, you should make upstream retranslate the Sympy expression
        # into Tensor expressions earlier and do that instead.
        if exp == 0.5:
            return self._helper_sqrt(base)
        elif exp == -0.5:
            return "1/" + self._helper_sqrt(base)
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        if exp > 0:
            return "*".join([self.paren(base)] * exp)
        elif exp < 0:
            return "1/" + self.paren("*".join([self.paren(base)] * abs(exp)))
        else:  # exp == 0
            return "1"
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_Trunc(self, expr):
    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        # This also could have been int(), they'll do the same thing for float
        return f"math.trunc({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

@@ -514,7 +607,7 @@ class PythonPrinter(ExprPrinter):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_Round(self, expr):
    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"
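Illustrative sketch (not from the patch): why the IntTrueDiv warning above matters. Python's int/int true division rounds the exact quotient once, while casting the operands to float first rounds twice and can drift once values exceed 2**53.

a, b = 2**53 + 1, 3
print(a / b)                # 3002399751580331.0  (exact quotient, rounded once)
print(float(a) / float(b))  # 3002399751580330.5  (operands rounded first; result differs)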
@@ -653,6 +746,29 @@ class OpOverrides:
            )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def trunc_to_int(a, dtype):
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def floor_to_int(a, dtype):
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a, dtype):
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def round_to_int(a, dtype):
        return ops.to_dtype(ops.round(a), dtype)

    @staticmethod
    def int_truediv(a, b):
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))
@@ -275,11 +275,11 @@ def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length:
    original_index = index

    div = sympy.Wild("divisor")
    div = sympy.Wild("divisor", integer=True)
    if index.has(FloorDiv):
        index = index.replace(FloorDiv(var, div), visit_indexing_div)

    mod = sympy.Wild("modulus")
    mod = sympy.Wild("modulus", integer=True)
    if index.has(ModularIndexing):
        index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing)
@@ -100,10 +100,53 @@ class CppPrinter(ExprPrinter):
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_Trunc(self, expr):
    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        r = f"std::trunc({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
        return f"static_cast<{INDEX_TYPE}>({r})"

    def _print_TruncToFloat(self, expr):
        assert len(expr.args) == 1
        return f"std::trunc({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"static_cast<double>({self._print(expr.args[0])})"

    # TODO: This is wrong if one of the inputs is negative. This is hard to
    # tickle though, as the inputs are typically positive (and if we can prove
    # they are positive, we will have used Mod instead, for which this codegen
    # is right).
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_CMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        # TODO: This is only accurate up to 2**53
        return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"

    # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
    # use std::pow, that operates on floats
    def _print_PowByNatural(self, expr):
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _print_FloatPow(self, expr):
        base, exp = expr.args
        return f"std::pow({self._print(base)}, {self._print(exp)})"

    def _print_Pow(self, expr):
        # Uses float constants to perform FP div

@@ -139,6 +182,11 @@ class CppPrinter(ExprPrinter):
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_Min(self, expr):
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:

@@ -200,8 +248,9 @@ class CppPrinter(ExprPrinter):
    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return f"std::sqrt({self._print(expr.args[0])})"

    def _print_Round(self, expr):
    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        # TODO: dispatch to llrint depending on index type
        return f"std::lrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
@@ -272,23 +272,68 @@ def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]):
    return f"{value}[{', '.join(expand)}]"


# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a
# number of operators which Triton "implements", but in a way that is
# inconsistent with Python semantics (and consistent with C semantics). We
# must override all of these, or it is a potential silent correctness problem
class TritonPrinter(PythonPrinter):
    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        return (
            f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
        )

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)"

    # TODO: This is wrong if one of the inputs is negative. This is hard to
    # tickle though, as the inputs are typically positive (and if we can prove
    # they are positive, we will have used Mod instead, for which this codegen
    # is right). If you are trying to hit this, maybe try something like
    # torch.arange(n, device="cuda") - 1 and then do a modulus on it
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    # TODO: This is wrong, see
    # https://github.com/triton-lang/triton/issues/955
    # But for Sympy expressions, things will /mostly/ work out because we
    # don't usually deal with negative numbers in the division
    def _print_FloorDiv(self, expr):
        assert expr.is_integer
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
    # precision algorithm, which we would need to replicate here
    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    # NB: sympy.floor/ceiling produce integers, so we have to do the
    # conversion to index dtype
    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return (
            f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
        )

    def _print_Trunc(self, expr):
    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        return (
            f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
            f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
        )

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"

    def _helper_sqrt(self, expr):
        return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))"

@@ -359,20 +404,9 @@ class TritonPrinter(PythonPrinter):
        assert len(expr.args) == 1
        return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_FloorDiv(self, expr):
        if expr.is_integer:
            return super()._print_FloorDiv(expr)

        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"libdevice.floor({x} / {div}).to({V.kernel.index_dtype})"

    def _print_Round(self, expr):
    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return (
            f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
        )
        return f"libdevice.llrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
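Illustrative sketch (not from the patch): the C/Python mismatch the TritonPrinter comments warn about. Python's % and // follow the sign of the divisor and round toward negative infinity, while C-style (and Triton-style) operations truncate toward zero.

import math
a, b = -7, 3
print(a % b, a // b)                       # 2 -3   (Python semantics)
print(math.fmod(a, b), math.trunc(a / b))  # -1.0 -2 (C-style truncation)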
@@ -1196,8 +1196,11 @@ class GraphLowering(torch.fx.Interpreter):
        elif is_magic_method(n.target):
            # TODO: this is sus, it probably should be handled in the
            # lowerings themselves similarly to sym_size/sym-stride
            # https://github.com/pytorch/pytorch/issues/127789
            debug("is_magic_method")
            if isinstance(n.meta["val"], torch.SymInt):
            if isinstance(
                n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool)
            ):
                result = n.meta["val"].node.expr
            else:
                result = super().run_node(n)
@@ -44,7 +44,6 @@ from torch._prims_common import (
    is_boolean_dtype,
    is_float_dtype,
    make_channels_last_strides_for,
    make_contiguous_strides_for,
    StrideType,
)
from torch._subclasses.fake_tensor import get_schema_info

@@ -236,7 +235,7 @@ def ir_node_to_tensor(x, guard_shape=True):
    if is_storage_and_layout(x):
        stride = [shape_fn(s) for s in x.get_layout().stride]  # type: ignore[misc]
    else:
        stride = make_contiguous_strides_for(size)  # type: ignore[arg-type]
        stride = FlexibleLayout.contiguous_strides(size)  # type: ignore[arg-type]
    dtype = x.get_dtype()
    device = x.get_device()
    size = convert_shape_to_symint(size)

@@ -2766,6 +2765,7 @@ class FlexibleLayout(Layout):
    allow_indexing = False

    # WARNING! This doesn't handle zero size tensors correctly
    @staticmethod
    def contiguous_strides(sizes):
        if len(sizes) == 0:

@@ -5915,7 +5915,7 @@ def _prepare_convolution_fusion_create(
    # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last.
    dynamic_shapes = not all(isinstance(i, int) for i in (output_size))
    if dynamic_shapes and is_contiguous_storage_and_layout(x):
        output_stride = make_contiguous_strides_for(output_size)
        output_stride = FlexibleLayout.contiguous_strides(output_size)
    else:
        output_stride = make_channels_last_strides_for(output_size)

@@ -5967,7 +5967,7 @@ def _prepare_linear_fusion_create(
    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
    inputs = [x, weight]

    output_stride = make_contiguous_strides_for(output_size)
    output_stride = FlexibleLayout.contiguous_strides(output_size)
    kernel_layout = FixedLayout(
        x.get_device(),
        x.get_dtype(),

@@ -6283,7 +6283,7 @@ class MKLPackedLinear(ExternKernelAlloc):
        *m, _ = x.get_size()
        oc, _ = orig_w.get_size()
        output_size = list(m) + [oc]
        output_stride = make_contiguous_strides_for(output_size)
        output_stride = FlexibleLayout.contiguous_strides(output_size)
        inputs = [x, packed_w, orig_w]
        constant_args = [batch_size]
        if B is not None:

@@ -6601,13 +6601,13 @@ class MkldnnRnnLayer(ExternKernelAlloc):
        def get_strides_of_lstm_output(output_shape, batch_first):
            assert len(output_shape) == 3, "Expect output_shape to be 3D"
            return make_contiguous_strides_for(output_shape)
            return FlexibleLayout.contiguous_strides(output_shape)

        output_sizes = [output_shape, hy_shape, cy_shape]
        output_strides = [
            get_strides_of_lstm_output(output_shape, batch_first),
            make_contiguous_strides_for(hy_shape),
            make_contiguous_strides_for(cy_shape),
            FlexibleLayout.contiguous_strides(hy_shape),
            FlexibleLayout.contiguous_strides(cy_shape),
        ]
        output_ir = [
            MultiOutput(
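Illustrative sketch (assumption, not from the patch): the quantity both make_contiguous_strides_for and FlexibleLayout.contiguous_strides compute, shown with a hypothetical standalone helper for clarity.

def contiguous_strides(sizes):
    # rightmost dimension varies fastest, so strides accumulate from the right
    strides, acc = [], 1
    for s in reversed(sizes):
        strides.append(acc)
        acc *= s
    return list(reversed(strides))

print(contiguous_strides([2, 3, 4]))  # [12, 4, 1]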
@@ -5,7 +5,6 @@ from enum import auto, Enum
from typing import Any, List, Tuple

import torch
from torch._prims_common import make_contiguous_strides_for
from .. import config
from ..ir import (
    ComputedBuffer,

@@ -389,7 +388,7 @@ def flex_attention(*args, **kwargs):
        query.get_device(),
        query.get_dtype(),
        query.get_size(),
        make_contiguous_strides_for(query.get_size()),
        FlexibleLayout.contiguous_strides(query.get_size()),
    )
    # see NOTE:[TritonTemplates with multiple outputs]
    logsumexp_shape = query.get_size()[:-1]  # [B, H, M]

@@ -745,7 +744,7 @@ def flex_attention_backward(*args, **kwargs):
        key.get_device(),
        key.get_dtype(),
        key.get_size(),
        make_contiguous_strides_for(key.get_size()),
        FlexibleLayout.contiguous_strides(key.get_size()),
    )

    # Create delta, which is needed for the bwd's kernel
@@ -34,7 +34,7 @@ from torch._prims_common import (
    Number,
)
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing
from .._dynamo.utils import import_submodule

from . import config, inductor_prims, ir, test_operators  # NOQA: F401

@@ -4262,7 +4262,7 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim):
    out_sz = out_sz[dim]
    in_sz = in_sz[dim]
    kernel_sz = kernel_sz[dim]
    alpha = (in_sz - kernel_sz) / (out_sz - 1)
    alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1)
    samples_loader = samples.make_loader()

    def load(prefix, i):

@@ -4372,7 +4372,7 @@ def upsample_nearest2d_backward(
    w_kernel_max = ceildiv(inp_w, out_w)

    def start_index(index, out_dim, inp_dim):
        return CeilDiv(index * inp_dim, out_dim)
        return CeilDiv(index * inp_dim, sympy.sympify(out_dim))

    def end_index(index, out_dim, inp_dim):
        return start_index((index + 1), out_dim, inp_dim)
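Illustrative sketch (assumes sympy and torch.utils._sympy.functions are importable): why the lowering switches from plain `/` to IntTrueDiv. Ordinary sympy division of integer operands happily produces an exact Rational, which the stricter ValueRanges typing now rejects for float-valued quantities; IntTrueDiv stays an explicit float-valued operation.

import sympy
from torch.utils._sympy.functions import IntTrueDiv
s = sympy.Symbol("s", integer=True, positive=True)
print(sympy.sympify(5) / 2)   # 5/2, a Rational
print(IntTrueDiv(s - 1, 2))   # stays an IntTrueDiv node, later evaluated as a float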
@@ -138,6 +138,38 @@ class OpsHandler(Protocol[T]):
        """
        ...

    def trunc_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with truncation semantics (similar to how the int
        constructor works in Python). In Inductor codegen, this just decays
        to trunc and then to_dtype, but this composite operation helps
        roundtrips for Sympy evaluation.

        dtype is taken as an explicit parameter because the desired output
        dtype is typically the index dtype, which may vary between int32 and
        int64 depending on if we've shown that all the indexing operations can
        be done in int32.
        """
        ...

    def ceil_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with ceiling semantics. See also trunc_to_int.
        """
        ...

    def floor_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with floor semantics. See also trunc_to_int.
        """
        ...

    def round_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with round-to-even semantics. See also trunc_to_int.
        """
        ...

    def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
        """
        Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)

@@ -398,21 +430,23 @@ class OpsHandler(Protocol[T]):
    def isnan(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    # This rounds half to even to break ties
    def round(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def floor(self, x0: T) -> T:
        ...

    def sign(self, x0: T) -> T:
        ...

    def to_int(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def trunc(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def ceil(self, x0: T) -> T:
        ...

@@ -449,6 +483,7 @@ class OpsHandler(Protocol[T]):
    def mul(self, x0: T, x1: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def pow(self, x0: T, x1: T) -> T:
        ...

@@ -617,14 +652,21 @@ class OpsHandler(Protocol[T]):
    def floordiv(self, x0: T, x1: T) -> T:
        """Python-style floor division between integers only. Computes the
        true division of two numbers and floors the result.
        true division of two numbers and floors the result. If you want
        floor division for floats, do regular truediv and floor the result.
        """
        ...

    def truediv(self, x0: T, x1: T) -> T:
        """True division between floats. Integer inputs are NOT valid: to do
        Python style (int, int) -> float division, promote the inputs to float
        first."""
        """True division between floats. Integer inputs are NOT valid. To
        do Python-style (int, int) -> float division, use int_truediv"""
        ...

    def int_truediv(self, x0: T, x1: T) -> T:
        """True division between integers. This is NOT the same as promoting
        to float and doing integer division, there is a bespoke algorithm for
        doing the division in higher precision than the above.
        """
        ...

    def div(self, x0: T, x1: T) -> T:

@@ -640,6 +682,10 @@ class OpsHandler(Protocol[T]):
        """Python-style modulus, take sign from RHS (x1)."""
        ...

    def round_decimal(self, x0: T, x1: T) -> T:
        """Python-style round with decimal argument"""
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # In CUDA, optimized implementations of other mathematical operations are
    # offered separately via libdevice for double precision computation (in
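Illustrative sketch (not from the patch): the "round-to-even" tie-breaking that round_to_int is documented to follow, matching Python's builtin round.

print(round(0.5), round(1.5), round(2.5))  # 0 2 2, ties go to the even integer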
@@ -386,7 +386,7 @@ class TritonTemplateKernel(TritonKernel):
        assert isinstance(mask, (str, type(None)))
        assert self.template_mask is None
        indices = list(map(TritonPrinter.paren, indices))
        index_symbols = [sympy.Symbol(x) for x in indices]
        index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
        lengths = [
            V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
        ]

@@ -410,7 +410,7 @@ class TritonTemplateKernel(TritonKernel):
        output_index = self.output_node.get_layout().make_indexer()(index_symbols)
        output_index = self.rename_indexing(output_index)
        if output_index == contiguous_index:
            output_index = sympy.Symbol("xindex")
            output_index = sympy.Symbol("xindex", integer=True)

        epilogue_args = [val]
        for input_node in itertools.chain(
@@ -161,9 +161,9 @@ class SizeVarAllocator:
        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                    sympy.Wild("modulus"),
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                    sympy.Wild("modulus", integer=True),
                ),
                visit_modular_indexing,
            )

@@ -171,8 +171,8 @@ class SizeVarAllocator:
        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                ),
                visit_indexing_div,
            )

@@ -604,11 +604,11 @@ def _join_dimensions_cached(expr: Expr) -> Expr:
    """
    assert isinstance(expr, sympy.Add)

    scale = sympy.Wild("scale", exclude=[0])
    base = sympy.Wild("base")
    divisor = sympy.Wild("divisor")
    mod1 = sympy.Wild("modulus")
    mod2 = sympy.Wild("modulus2")
    scale = sympy.Wild("scale", exclude=[0], integer=True)
    base = sympy.Wild("base", integer=True)
    divisor = sympy.Wild("divisor", integer=True)
    mod1 = sympy.Wild("modulus", integer=True)
    mod2 = sympy.Wild("modulus2", integer=True)
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
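Illustrative sketch (assumption, not from the patch): why the Wild symbols now carry integer=True. The commit message notes that symbols allocated by Inductor must be marked is_integer so the stricter integer-only functions accept them when the patterns above are constructed; a plain Wild reports an unknown integer-ness.

import sympy
d_ok = sympy.Wild("divisor", integer=True)
d_plain = sympy.Wild("divisor")
print(d_ok.is_integer)     # True
print(d_plain.is_integer)  # None, i.e. unknown, which the new checks do not accept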
@@ -192,7 +192,7 @@ def ceildiv(
    numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
) -> Union[int, sympy.Expr]:
    if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr):
        return CeilDiv(numer, denom)
        return CeilDiv(sympy.sympify(numer), sympy.sympify(denom))
    # TODO: There is a bug in a call to this function, to repro:
    # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
    # --amp --only YituTechConvBert --dynamic-shapes

@@ -1727,7 +1727,7 @@ class FakeTensorMode(TorchDispatchMode):
        for run_impl_check, op_impl in op_implementations_checks:
            if run_impl_check(func):
                op_impl_out = op_impl(self, func, *args, **kwargs)
                if op_impl_out != NotImplemented:
                if op_impl_out is not NotImplemented:
                    return maybe_propagate_real_tensors(op_impl_out)

        def maybe_run_unsafe_fallback(error=None):
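Illustrative sketch (not from the patch): why the identity check matters here. `!=` invokes __ne__, which for rich types like tensors may not return a plain bool, so truth-testing the result is fragile; `is not` only asks whether the sentinel object itself was returned.

class Weird:
    def __ne__(self, other):
        return [True, False]  # not a bool, like an elementwise comparison result

w = Weird()
print(w is not NotImplemented)  # True, a pure identity check
print(w != NotImplemented)      # [True, False], truthy for the wrong reason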
@@ -1200,8 +1200,13 @@ void initJITBindings(PyObject* module) {
  SYMNODE_BINARY(sub)
  SYMNODE_BINARY(mul)
  SYMNODE_BINARY(truediv)
  SYMNODE_BINARY(int_truediv)
  SYMNODE_BINARY(float_truediv)
  SYMNODE_BINARY(pow)
  SYMNODE_BINARY(float_pow)
  SYMNODE_BINARY(pow_by_natural)
  SYMNODE_BINARY(floordiv)
  SYMNODE_BINARY(int_floordiv)
  SYMNODE_BINARY(mod)
  SYMNODE_BINARY(eq)
  SYMNODE_BINARY(ne)

@@ -198,14 +198,34 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode float_truediv(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode int_truediv(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode pow(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode float_pow(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode pow_by_natural(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode floordiv(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode int_floordiv(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode mod(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }
@@ -1,7 +1,6 @@
import builtins
import dataclasses
import inspect
import math
import sys
import weakref
from collections import defaultdict

@@ -254,11 +253,14 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
    shared: Optional[_ConstraintTarget] = None
    debug_name: Optional[str] = None

    def _clone_with_range(self, lower=0, upper=math.inf):
    def _clone_with_range(self, lower=0, upper=None):
        # Import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.value_ranges import ValueRanges

        if upper is None:
            upper = sys.maxsize - 1

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
            warn_only=False,

@@ -486,7 +488,6 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
        )

    # Import sympy locally
    import sympy

    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
    from torch.utils._sympy.value_ranges import ValueRanges

@@ -496,7 +497,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
        id(t),
        index,
        StrictMinMaxConstraint(
            vr=ValueRanges(lower=0, upper=sympy.oo), warn_only=False
            vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False
        ),
        debug_name=debug_name,
    )
@@ -277,7 +277,13 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
                raise

            except Exception:
                log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs)
                log.error(  # noqa: G201
                    "failed while running %s(*%s, **%s)",
                    name,
                    args[1:],
                    kwargs,
                    exc_info=log.isEnabledFor(logging.INFO),
                )
                raise

        return wrapper
@@ -267,8 +267,11 @@ class SymNode:
    def mod(self, other) -> "SymNode":
        return self._mod(other)  # type: ignore[attr-defined]

    def pow(self, other) -> "SymNode":
        return self._pow(other)  # type: ignore[attr-defined]
    def float_pow(self, other) -> "SymNode":
        return self._float_pow(other)  # type: ignore[attr-defined]

    def pow_by_natural(self, other) -> "SymNode":
        return self._pow_by_natural(other)  # type: ignore[attr-defined]

    def and_(self, other) -> "SymNode":
        return self._and_(other)  # type: ignore[attr-defined]

@@ -276,11 +279,14 @@ class SymNode:
    def or_(self, other) -> "SymNode":
        return self._or_(other)  # type: ignore[attr-defined]

    def truediv(self, other) -> "SymNode":
        return self._truediv(other)  # type: ignore[attr-defined]
    def float_truediv(self, other) -> "SymNode":
        return self._float_truediv(other)  # type: ignore[attr-defined]

    def floordiv(self, other) -> "SymNode":
        return self._floordiv(other)  # type: ignore[attr-defined]
    def int_truediv(self, other) -> "SymNode":
        return self._int_truediv(other)  # type: ignore[attr-defined]

    def int_floordiv(self, other) -> "SymNode":
        return self._int_floordiv(other)  # type: ignore[attr-defined]

    def lshift(self, other) -> "SymNode":
        return self._lshift(other)  # type: ignore[attr-defined]

@@ -361,6 +367,17 @@ class SymNode:
    def sym_and(self, other):
        return self.and_(other)

    # There is no int_truediv available from C++
    def truediv(self, other):
        return self.float_truediv(other)

    def floordiv(self, other) -> "SymNode":
        return self.int_floordiv(other)

    # We didn't bind integer pow in C++
    def pow(self, other):
        return self.float_pow(other)

    def is_non_overlapping_and_dense(self, sizes, strides):
        return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1))  # type: ignore[attr-defined]

@@ -477,7 +494,7 @@ METHOD_TO_OPERATOR = {
    "eq": operator.eq,
    "floor": math.floor,
    "trunc": math.trunc,
    "floordiv": operator.floordiv,
    "int_floordiv": operator.floordiv,
    "ge": operator.ge,
    "gt": operator.gt,
    "is_integer": lambda x: x.is_integer(),

@@ -489,7 +506,8 @@ METHOD_TO_OPERATOR = {
    "ne": operator.ne,
    "neg": operator.neg,
    "or": operator.or_,
    "pow": operator.pow,
    "float_pow": operator.pow,
    "pow_by_natural": operator.pow,
    "round": builtins.round,
    "rshift": operator.rshift,
    "sub": operator.sub,

@@ -498,12 +516,14 @@ METHOD_TO_OPERATOR = {
    "sym_max": sym_max,
    "sym_min": sym_min,
    "sym_not": sym_not,
    "truediv": operator.truediv,
    "float_truediv": operator.truediv,
    "int_truediv": operator.truediv,
}

unary_magic_methods = {
    "abs",
    "sym_float",
    "sym_int",
    "ceil",
    "floor",
    "neg",

@@ -559,20 +579,20 @@ also_bool_magic_methods = {"eq"}
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods

# Methods that are only for float
only_float_magic_methods = {"is_integer"}
only_float_magic_methods = {"is_integer", "round", "sym_int"}


magic_methods_on_operator_with_trailing_underscore = {"and", "or"}


always_float_magic_methods = {"truediv", "sym_float", "pow"}
always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}

for name in math_op_names:
    sym_name = f"sym_{name}"
    always_float_magic_methods.add(sym_name)


always_int_magic_methods = {"ceil", "floor", "trunc"}
always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
always_bool_magic_methods = {
    "eq",
    "ne",
@@ -590,10 +610,16 @@ always_bool_magic_methods = {
# Methods that have a `__foo__` as well as `__rfoo__`


def _sympy_truediv(a, b):
    from torch.utils._sympy.functions import TrueDiv
def _sympy_float_truediv(a, b):
    from torch.utils._sympy.functions import FloatTrueDiv

    return TrueDiv(a, b)
    return FloatTrueDiv(a, b)


def _sympy_int_truediv(a, b):
    from torch.utils._sympy.functions import IntTrueDiv

    return IntTrueDiv(a, b)


def _sympy_floordiv(a, b):

@@ -603,15 +629,24 @@ def _sympy_floordiv(a, b):


def _sympy_mod(a, b):
    from torch.utils._sympy.functions import Mod
    from torch.utils._sympy.functions import Mod, PythonMod

    return Mod(a, b)
    if a.is_nonnegative and b.is_nonnegative:
        return Mod(a, b)
    else:
        return PythonMod(a, b)


def _sympy_pow(a, b):
    from torch.utils._sympy.functions import Pow
def _sympy_pow_by_natural(a, b):
    from torch.utils._sympy.functions import PowByNatural

    return Pow(a, b)
    return PowByNatural(a, b)


def _sympy_float_pow(a, b):
    from torch.utils._sympy.functions import FloatPow

    return FloatPow(a, b)


def _sympy_and(a, b):

@@ -643,11 +678,13 @@ reflectable_magic_methods = {
    "sub": operator.sub,
    "mul": operator.mul,
    "mod": _sympy_mod,
    "pow": _sympy_pow,
    "pow_by_natural": _sympy_pow_by_natural,
    "float_pow": _sympy_float_pow,
    "and": _sympy_and,
    "or": _sympy_or,
    "truediv": _sympy_truediv,
    "floordiv": _sympy_floordiv,
    "float_truediv": _sympy_float_truediv,
    "int_truediv": _sympy_int_truediv,
    "int_floordiv": _sympy_floordiv,
    "lshift": _sympy_lshift,
    "rshift": _sympy_rshift,
}

@@ -672,21 +709,23 @@ def _floor_ceil_helper(a, fn):


def _sympy_floor(a):
    import sympy
    from torch.utils._sympy.functions import FloorToInt

    return _floor_ceil_helper(a, sympy.floor)
    return FloorToInt(a)


# NB: this is Python trunc semantics which returns an int. Do NOT use this to
# represent torch.trunc (which is float to float)
def _sympy_trunc(a):
    from torch.utils._sympy.functions import Trunc
    from torch.utils._sympy.functions import TruncToInt

    return Trunc(a)
    return TruncToInt(a)


def _sympy_ceil(a):
    import sympy
    from torch.utils._sympy.functions import CeilToInt

    return _floor_ceil_helper(a, sympy.ceiling)
    return CeilToInt(a)


def _sympy_eq(a, b):

@@ -771,26 +810,28 @@ def _sympy_abs(a):


def _sympy_round(number, ndigits=None):
    from torch.utils._sympy.functions import Round, RoundDecimal
    from torch.utils._sympy.functions import RoundDecimal, RoundToInt

    if ndigits is None:
        return Round(number)
        return RoundToInt(number)
    else:
        return RoundDecimal(number, ndigits)


def _sympy_sym_float(a):
    # Cannot use sympy.Float(a) here, coz it expects python literals
    # Multiply by 1.0 to cast to float. This is needed when the input
    # is a SymInt which has the assumption that it is integer and
    # SymPy will otherwise assume that return value cannot be a float.
    return a * 1.0
    from torch.utils._sympy.functions import ToFloat

    # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly
    # reports that it is an integer
    return ToFloat(a)


def _sympy_is_integer(a):
    import sympy

    return sympy.Eq(sympy.floor(a), a)
    from torch.utils._sympy.functions import ToFloat

    return sympy.Eq(ToFloat(sympy.floor(a)), a)


magic_methods = {
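Illustrative sketch (not from the patch): why _sympy_mod can use Mod for nonnegative operands but needs PythonMod in general. For nonnegative arguments C-style and Python-style modulus agree; with a negative operand they differ.

print(7 % 3, -7 % 3)  # 1 2, Python's result takes the sign of the divisor
# A C-style remainder would give -1 for -7 % 3, so the two operations only
# coincide when both arguments are known to be >= 0.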
@@ -989,9 +1030,26 @@ def _make_node_magic(method, func):
            self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
        )
        assert isinstance(other, SymNode)
        # TODO: consider constant prop here
        try:
            out = func(self.expr, other.expr)
            if method == "mod":
                from torch.utils._sympy.functions import Mod, PythonMod

                # Special handling for mod that requires access to the value
                # ranges
                shape_env = self.shape_env
                if (
                    self.expr.is_nonnegative
                    or shape_env.bound_sympy(self.expr).lower >= 0
                ) and (
                    other.expr.is_nonnegative
                    or shape_env.bound_sympy(other.expr).lower >= 0
                ):
                    out = Mod(self.expr, other.expr)
                else:
                    out = PythonMod(self.expr, other.expr)
            else:
                # TODO: consider constant prop here
                out = func(self.expr, other.expr)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
            raise

@@ -1122,9 +1180,13 @@ def _make_node_magic(method, func):
        except Exception:
            log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
            raise

        out = safe_expand(out)

        pytype = int if ndigits is None else self.pytype
        if ndigits is None:
            pytype = int
        else:
            pytype = self.pytype

        out_hint = None
        if self.hint is not None:

@@ -1136,6 +1198,7 @@ def _make_node_magic(method, func):
        # hack down below works, because all round function down the line all take ndigits=None as default in their
        # signature.
        # TODO: Remove the args construction below if a different sentinel is used by FX.
        # ezyang(May 2024): LOL
        args = [self.fx_node]
        if ndigits is not None:
            args.append(ndigits)

@@ -1259,6 +1322,32 @@ def _make_user_magic(method, user_type):
            return x.node.is_constant()
        return False

    # Promotion rules for binary operations. NB: we preserve PYTHON semantics
    #   - if args are same type, do nothing
    #   - if one arg is float, promote other arg to float
    #       - nb: this applies to floordiv, even though output is integral
    #         (it's still float)
    #   - pow is funny business
    #       - if both ints
    #       - trigger a guard on exponent >= 0
    #           - if non-negative, output is int
    #           - otherwise, output is float
    #       - otherwise, promote other arg to float
    #           - nb: complex is impossible to handle correctly lol, with
    #             negative base and integral float need to diverge semantics and
    #             just always return complex. Neener neener pretend this problem
    #             doesn't exist
    #   - equality is pain: Python does the fancy thing where it unpacks the
    #     mantissa from the float and then compares that against the int.
    #     Which means it is able to tell that
    #     9007199254740993 != 9007199254740992. (rather than if the LHS was
    #     promoted to float, in which case it would have truncated to the RHS
    #     and subsequently been equal). We'll model this exactly by having
    #     special mixed type equality operations. Unfortunately, we need to
    #     do this for all comparison operations (maybe I'll only implement
    #     compare)
    #   - sym_ite mumble mumble really shouldn't allow mixed but whatever

    if method in bool_becomes_int_magic_methods:

        def promote(x):

@@ -1272,6 +1361,41 @@ def _make_user_magic(method, user_type):
        def promote(x):
            return x

    def promote2(self, other):
        # TODO: Remove eq and other relations from this list.
        # CPython has fancy implementations for these to get as much precision
        # as possible instead of just promoting to float64 and praying, so we
        # need to handle them specially too.
        # Also, note that int_truediv doesn't go through this path: both
        # arguments are "int" so there isn't any promotion
        if method not in [
            "add",
            "sub",
            "mul",
            "mod",
            "float_pow",
            "float_truediv",
            "int_floordiv",
            "sym_min",
            "sym_max",
            # TODO: remove these
            "eq",
            "ne",
            "gt",
            "lt",
            "le",
            "ge",
        ]:
            return self, other
        f_self = isinstance(self, (float, torch.SymFloat))
        f_other = isinstance(other, (float, torch.SymFloat))
        if f_self or f_other:
            if not f_self:
                self = torch.sym_float(self)
            if not f_other:
                other = torch.sym_float(other)
        return self, other

    # Before and after performing the operation, check if any operands are constant.
    # If so, extract out the constant values first. If `self` itself is a
    # constant, then "redispatch" by calling back into the operator. Sometimes

@@ -1286,9 +1410,12 @@ def _make_user_magic(method, user_type):
        return wrap_node(getattr(self.node, method_attr)())

    def binary_magic_impl(self, other):
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        sym_node_log.debug("MAGIC %s %s %s", method, self, other)
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self), other)
        if is_constant(other):

@@ -1300,8 +1427,11 @@ def _make_user_magic(method, user_type):
        return get_constant(ret) if is_constant(ret) else ret

    def rbinary_magic_impl(self, other):
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self), other)
        if is_constant(other):
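Illustrative sketch (not from the patch): the promotion behavior promote2 encodes, contrasted with Python's builtin max, whose result type depends on which operand wins.

print(max(2, 1.5), type(max(2, 1.5)))  # 2 <class 'int'>, the larger operand's type wins
print(max(1, 2.5), type(max(1, 2.5)))  # 2.5 <class 'float'>
# With the promotion rules above, a symbolic sym_max(int, float) is always typed
# float, so the output type can be determined without runtime values.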
@@ -61,7 +61,7 @@ from torch._logging import trace_structured, structured
from torch import SymBool, SymFloat, SymInt
from torch._guards import ShapeGuard, Source, TracingContext
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator
from torch.utils._sympy.functions import FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt

@@ -869,9 +869,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
    for N=1.
    """
    if min is None:
        min = -sympy.oo
        min = -sys.maxsize - 1
    if max is None:
        max = sympy.oo
        max = sys.maxsize - 1

    if max < min:
        raise ValueError(

@@ -979,16 +979,6 @@ def eval_guards(gm, *args, ignore_static=True):
def bind_symbols(gm, *args):
    return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)

def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges):
    """
    We assert that the bounds are either Boolean, or not finite, or can be computed
    in exact precision via rational arithmetic.
    The only exception to this is the rare case when the user calls `sqrt(s0)`
    sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still)
    """
    assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr)
    assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr)

class DimDynamic(Enum):
    """
    Controls how to perform symbol allocation for a dimension. It is always

@@ -1387,14 +1377,19 @@ SYMPY_INTERP = {
    'Min': min,
    'Max': max,
    'Mod': operator.mod,
    'PythonMod': operator.mod,
    'FloorDiv': operator.floordiv,
    'TrueDiv': operator.truediv,
    'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
    'floor': math.floor,
    'ceiling': math.ceil,
    'FloorToInt': math.floor,
    'CeilToInt': math.ceil,
    'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless,
    'Round': builtins.round,
    'RoundToInt': builtins.round,
    'RoundDecimal': builtins.round,
    'TruncToInt': math.trunc,
    'IntTrueDiv': operator.truediv,
}


@@ -1642,10 +1637,17 @@ class DimConstraints:
            congruence = (base - mod_reduced) % divisor
            if congruence != 0:
                self._congruences[s].add(congruence)
            # NB: Must not be CleanDiv, it needs to be regular sympy division
            # so inequality solver works. This is sort of problematic for
            # is_integer tests though haha
            return (base - mod_reduced) / divisor

        if expr.has(Mod):
            expr = expr.replace(Mod, mod_handler)
        # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative
        # arguments should be OK.
        if expr.has(PythonMod):
            expr = expr.replace(PythonMod, mod_handler)
        if expr.has(FloorDiv):
            expr = expr.replace(FloorDiv, floor_div_handler)
        return expr
@@ -3330,6 +3332,7 @@ class ShapeEnv:
        self.pending_fresh_unbacked_symbols.append(symbol)
        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
        vr = self.var_to_range[symbol] = ValueRanges.unknown()
        assert vr.is_float

        # Create a new FX placeholder and Z3 variable for 'symbol'.
        fx_node = self._create_fx_placeholder_and_z3var(symbol, float)

@@ -3348,6 +3351,7 @@ class ShapeEnv:
        self.counter["create_unbacked_symbol"] += 1
        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
        vr = self.var_to_range[symbol] = self._default_unspecified_value_range()
        assert vr.is_int

        # Create a new FX placeholder and Z3 variable for 'symbol'.
        fx_node = self._create_fx_placeholder_and_z3var(symbol, int)

@@ -3371,6 +3375,7 @@ class ShapeEnv:
        self.counter["create_unbacked_symbol"] += 1
        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
        vr = self.var_to_range[symbol] = ValueRanges(0, 1)
        assert vr.is_int

        # Create a new FX placeholder and Z3 variable for 'symbol'.
        fx_node = self._create_fx_placeholder_and_z3var(symbol, bool)

@@ -3516,6 +3521,7 @@ class ShapeEnv:
                self.var_to_range[sympy_expr] &= constraint_dim.vr

            vr = self.var_to_range[sympy_expr]
            assert vr.is_int

            if val not in vr:
                raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")

@@ -3524,6 +3530,7 @@ class ShapeEnv:
        elif isinstance(val, float):
            self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo)
            range_str = f"[{vr.lower}, {vr.upper}]"
            assert vr.is_float
        else:
            # Skip var_range logic for SingletonInt
            # Only used for jagged layout nested tensors

@@ -3573,6 +3580,7 @@ class ShapeEnv:

    def add_var_to_val(self, expr: sympy.Symbol, val: int):
        """ Adds a new symbol to the symbolic environment. """
        log.debug("add_var_to_val %s %s", expr, val, stack_info=True)
        assert expr not in self.var_to_val, f"{expr} already exists"
        self.var_to_val[expr] = sympy.Integer(val)

@@ -4301,7 +4309,8 @@ class ShapeEnv:
        # Clamp values of size-like variables
        for x in self.size_like & var_to_range.keys():
            if var_to_range[x] is not None:
                var_to_range[x] = ValueRanges(2, sympy.oo)
                var_to_range[x] = ValueRanges(2, sys.maxsize - 1)
                assert var_to_range[x].is_int
        return bound_sympy(expr, var_to_range)

    @_lru_cache

@@ -4418,6 +4427,11 @@ class ShapeEnv:
                vr = self._default_unspecified_value_range()
            if size_oblivious and k in self.size_like:
                lower = max(2, vr.lower)
                # This is a bit dodgy: what this means is that there was a
                # size-like unbacked symbol whose upper bound < 2. This
                # causes... problems.
                if lower <= vr.upper:
                    vr = ValueRanges(lower, vr.upper)
            else:
                lower = vr.lower
            # Don't do anything if we don't have a nontrivial lower bound

@@ -4425,10 +4439,17 @@ class ShapeEnv:
            # SymInt
            if (
                lower < (-sys.maxsize - 1) // 2 or
                (unbacked_only and k in self.var_to_val)
                (unbacked_only and k in self.var_to_val) or
                not vr.is_int
            ):
                new_range_env[k] = vr
                continue
            # The goal is to take our symbols which have various lower bounds
            # and reallocate them into new symbols which are exactly positive;
            # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
            # [1, inf], where s0 = ess0 + 1. This gives the most information
            # to sympy for subsequent simplifications.
            #
            # Positive means >= 1
            # Positive - 1 means >= 0
            # Positive + lower - 1 means >= lower

@@ -4460,6 +4481,14 @@ class ShapeEnv:
            self.counter["sympy_recursion_error"] += 1
            return None

        new_expr = safe_expand(new_expr)
        if new_expr.is_number:
            return new_expr

        # This is bad to do, the replacement with division leaves us with
        # rationals when atom.args[0] is addition, e.g., sympy will happily
        # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication!
        """
        floor_div_replace = {}
        for atom in new_expr.atoms(FloorDiv):
            floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])

@@ -4468,13 +4497,12 @@ class ShapeEnv:
        # are still free symbols
        if new_expr.is_number:
            return new_expr
        """

        # Check if the range can solve it statically
        out = bound_sympy(new_expr, new_range_env)
        if expect_rational:
            _assert_bound_is_rational(new_expr, out)
            if out.is_singleton():
                return out.lower
        if out.is_singleton():
            return out.lower

        return new_expr if unbacked_only else None
@@ -4526,7 +4554,7 @@ class ShapeEnv:
        for fd in expr.atoms(FloorDiv):
            base, divisor = fd.args
            if self.replace(Mod(base, divisor)) in self.divisible:
                div_replacements[fd] = base / divisor
                div_replacements[fd] = CleanDiv(base, divisor)
        new_expr = expr.xreplace(div_replacements)
        new_expr = safe_expand(new_expr)
        new_pows = new_expr.atoms(sympy.Pow)

@@ -4670,7 +4698,10 @@ class ShapeEnv:
        int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)

        def issubset(x, y):
            return (x & int_range).issubset(y & int_range)
            if x.is_int and y.is_int:
                return (x & int_range).issubset(y & int_range)
            else:
                return x.issubset(y)

        # First, refine the value range of a based on the computed value range
        # of tgt. This is always OK to do, even if we decide not to do the

@@ -4688,7 +4719,7 @@ class ShapeEnv:
            b = next(iter(tgt.free_symbols))
            # Try to invert the equality
            r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
            if r is not None:
            if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])):
                b_bound = self.bound_sympy(r[1])
                self.var_to_range[b] = b_bound & self.var_to_range[b]
                tgt_bound = self.bound_sympy(tgt)

@@ -4899,12 +4930,12 @@ class ShapeEnv:
                ):
                    # We have Mod(i0, q / c) == 0, which means we can
                    # rewrite i0 as (q / gcd(q, c)) * i1
                    d = q / sympy.gcd(q, c)
                    d = q / sympy.gcd(q, c)  # TODO: CleanDiv?
                    i1 = self.create_unbacked_symint().node.expr
                    # Propagate the value ranges. It doesn't really
                    # matter if we use truediv or floordiv, because we
                    # have established divisibility.
                    self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv(
                    self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv(
                        self.var_to_range[i0], ValueRanges.wrap(d)
                    ))
                    # Propagate size-like-ness

@@ -5341,7 +5372,6 @@ class ShapeEnv:
        lower, upper = vr.lower, vr.upper

        rhs_vr = bound_sympy(rhs, self.var_to_range)
        _assert_bound_is_rational(rhs, rhs_vr)

        # Let's suppose that we have a preexisting range for x [0, 100].
        # Now, we issue a guard x > y, where the range for y is [50, 150].
@ -216,10 +216,7 @@ try:
|
|||
def abs(self, number: z3.ArithRef) -> z3.ArithRef:
|
||||
return z3.Abs(number)
|
||||
|
||||
def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef:
|
||||
if ndigits is not None:
|
||||
raise ValueError("round(..., ndigits=) is currently not supported by shape validations.")
|
||||
|
||||
def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef:
|
||||
# Pythons builtin 'round' implements the 'round half to even' strategy
|
||||
# See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
|
||||
# z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to
|
||||
|
|
@ -284,7 +281,7 @@ try:
|
|||
operator.truediv: lift(ops.div),
|
||||
operator.mod: lift(ops.mod),
|
||||
operator.abs: lift(ops.abs),
|
||||
builtins.round: lift(ops.round),
|
||||
builtins.round: lift(ops.round_to_int),
|
||||
|
||||
# Math module.
|
||||
math.ceil: lift(ops.ceil),
|
||||
|
|
@ -350,6 +347,7 @@ try:
|
|||
self._ops = _Z3Ops(self._validator)
|
||||
|
||||
def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
|
||||
# TODO: Probably OK to relax this and allow lower precision
|
||||
if dtype is torch.int64:
|
||||
return z3.IntVal(int(value))
|
||||
if dtype is torch.double:
|
||||
|
|
@ -358,6 +356,20 @@ try:
|
|||
return z3.BoolVal(bool(value))
|
||||
raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")
|
||||
|
||||
def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
|
||||
if dtype == torch.float64:
|
||||
return z3.ToReal(x)
|
||||
raise NotImplementedError(f"to_dtype {dtype} NYI")
|
||||
|
||||
def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
|
||||
return z3.ToInt(x)
|
||||
|
||||
def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
|
||||
return self._ops.round_to_int(x)
|
||||
|
||||
def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
|
||||
return self._ops.div(numerator, denominator)
|
||||
|
||||
def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
|
||||
return self._ops.div(numerator, denominator)
|
||||
|
||||
|
|
@ -370,11 +382,17 @@ try:
|
|||
def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
|
||||
return self._ops.pow(base, exp)
|
||||
|
||||
def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
|
||||
return self._ops.pow(base, exp)
|
||||
|
||||
def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
|
||||
return self._ops.mod(p, q)
|
||||
|
||||
def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef:
|
||||
return self._ops.round(number, ndigits)
|
||||
def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
|
||||
return self._ops.ceil(x)
|
||||
|
||||
def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
|
||||
return self._ops.floor(x)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
REPLACEMENT = {
|
||||
@ -1,43 +1,78 @@
|
|||
import functools
|
||||
import math
|
||||
import sys
|
||||
|
||||
import sympy
|
||||
from sympy import S
|
||||
from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or
|
||||
|
||||
__all__ = [
|
||||
"FloorDiv",
|
||||
"ModularIndexing",
|
||||
"CleanDiv",
|
||||
"CeilDiv",
|
||||
"Pow",
|
||||
"TrueDiv",
|
||||
"IntTrueDiv",
|
||||
"FloatTrueDiv",
|
||||
"LShift",
|
||||
"RShift",
|
||||
"IsNonOverlappingAndDenseIndicator",
|
||||
"Round",
|
||||
"RoundToInt",
|
||||
"RoundDecimal",
|
||||
"ToFloat",
|
||||
"FloatPow",
|
||||
"PowByNatural",
|
||||
]
|
||||
|
||||
|
||||
def _keep_float(f):
|
||||
@functools.wraps(f)
|
||||
def inner(*args):
|
||||
r = f(*args)
|
||||
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
|
||||
r, sympy.Float
|
||||
):
|
||||
r = sympy.Float(float(r))
|
||||
return r
|
||||
|
||||
return inner
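A quick illustration of what the decorator buys us (a minimal sketch, assuming the module layout in this PR; sympy.floor of a Float yields an Integer, which is exactly the kind of silent float-dropping the stricter typing forbids):

    import sympy
    from torch.utils._sympy.functions import _keep_float

    # Plain sympy.floor drops float-ness: Float in, Integer out.
    print(sympy.floor(sympy.Float(2.5)))               # 2   (sympy.Integer)
    # The wrapped version coerces the result back to a Float because one of
    # the inputs was a Float.
    print(_keep_float(sympy.floor)(sympy.Float(2.5)))  # 2.0 (sympy.Float)

This mirrors how value_ranges.py below uses _keep_float(sympy.floor) and _keep_float(sympy.ceiling).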
|
||||
|
||||
|
||||
def fuzzy_eq(x, y):
|
||||
if None in (x, y):
|
||||
return None
|
||||
return x == y
|
||||
|
||||
|
||||
# It would be nice to have assertions on whether or not inputs is_integer
|
||||
# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy
|
||||
# sometimes inconsistently reports floats as integers.
|
||||
#
|
||||
# What we can assume from sympy is that if something is an int, it
|
||||
# definitely is is_integer, but if it is a float it may or may not
|
||||
# be is_integer. So we are unable to do strong asserts that things
|
||||
# are NOT integers.
|
||||
|
||||
|
||||
# TODO: In Triton, // rounds to zero, but in Python, it is floor division.
|
||||
# When we can prove both arguments are non-negative, we should just have a
|
||||
# GenericFloorDiv (name pending) which can codegen efficiently in Python/C,
|
||||
# and then PythonFloorDiv and CIntDiv which have the appropriate rounding
|
||||
# semantics.
|
||||
#
|
||||
# Right now, FloorDiv de facto changes behavior if arguments are negative or
|
||||
# not, this can potentially cause correctness issues.
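For orientation, a plain-Python sketch of the two conventions the TODO is talking about (the trunc_div helper below is purely illustrative, not a PyTorch API); the two agree when the operands have the same sign and differ otherwise:

    import math

    def trunc_div(a, b):
        # C/Triton-style integer division: round the true quotient toward zero.
        return math.trunc(a / b)

    print(7 // 2, trunc_div(7, 2))    # 3 3    same sign: both give 3
    print(-7 // 2, trunc_div(-7, 2))  # -4 -3  mixed sign: floor vs truncation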
|
||||
class FloorDiv(sympy.Function):
|
||||
"""
|
||||
We maintain this so that:
|
||||
1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
|
||||
2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
|
||||
|
||||
NB: This is Python-style floor division, round to -Inf
|
||||
"""
|
||||
|
||||
nargs = (2,)
|
||||
precedence = 50 # precedence of mul # noqa: F811
|
||||
|
||||
# Default return type for SymPy assumptions.
|
||||
# https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
|
||||
is_real = True
|
||||
is_integer = True
|
||||
|
||||
@property
|
||||
def base(self):
|
||||
|
|
@ -52,29 +87,14 @@ class FloorDiv(sympy.Function):
|
|||
divisor = printer.parenthesize(self.divisor, self.precedence)
|
||||
return f"({base}//{divisor})"
|
||||
|
||||
# SymPy assumptions based on argument types.
|
||||
def _eval_is_real(self):
|
||||
return fuzzy_or([self.base.is_real, self.divisor.is_real])
|
||||
|
||||
def _eval_is_integer(self):
|
||||
return fuzzy_and([self.base.is_integer, self.divisor.is_integer])
|
||||
|
||||
# Automatic evaluation.
|
||||
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
|
||||
@classmethod
|
||||
def eval(cls, base, divisor):
|
||||
def check_supported_type(x):
|
||||
if (
|
||||
x.is_integer is False and x.is_real is False and x.is_complex
|
||||
) or x.is_Boolean:
|
||||
raise TypeError(
|
||||
f"unsupported operand type(s) for //: "
|
||||
f"'{type(base).__name__}' and '{type(divisor).__name__}'"
|
||||
f", expected integer or real"
|
||||
)
|
||||
|
||||
check_supported_type(base)
|
||||
check_supported_type(divisor)
|
||||
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
|
||||
# Assert triggered by inequality solver
|
||||
# assert base.is_integer, base
|
||||
# assert divisor.is_integer, divisor
|
||||
|
||||
# We don't provide the same error message as in Python because SymPy
|
||||
# makes it difficult to check the types.
|
||||
|
|
@ -85,26 +105,22 @@ class FloorDiv(sympy.Function):
|
|||
return sympy.S.Zero
|
||||
if base.is_integer and divisor == 1:
|
||||
return base
|
||||
if base.is_real and divisor == 1:
|
||||
return sympy.floor(base)
|
||||
if base.is_integer and divisor == -1:
|
||||
return sympy.Mul(base, -1)
|
||||
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
|
||||
return base // divisor
|
||||
if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(
|
||||
divisor, (sympy.Integer, sympy.Float)
|
||||
):
|
||||
return sympy.floor(base / divisor)
|
||||
return sympy.Integer(int(base) // int(divisor))
|
||||
if isinstance(base, FloorDiv):
|
||||
return FloorDiv(base.args[0], base.args[1] * divisor)
|
||||
if isinstance(divisor, sympy.Rational) and divisor.p == 1:
|
||||
return sympy.floor(base * divisor.q)
|
||||
|
||||
# gcd in sympy is over polynomials, so you'll end up with rationals if
|
||||
# you do this. Don't.
|
||||
"""
|
||||
if isinstance(base, sympy.Add):
|
||||
for a in base.args:
|
||||
gcd = sympy.gcd(a, divisor)
|
||||
if gcd == divisor:
|
||||
return FloorDiv(base - a, divisor) + a / gcd
|
||||
"""
|
||||
|
||||
try:
|
||||
gcd = sympy.gcd(base, divisor)
|
||||
|
|
@ -189,6 +205,19 @@ class Where(sympy.Function):
|
|||
|
||||
nargs = (3,)
|
||||
|
||||
def _eval_is_integer(self):
|
||||
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
|
||||
|
||||
def _eval_is_nonnegative(self):
|
||||
return (
|
||||
True
|
||||
if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined]
|
||||
else None
|
||||
)
|
||||
|
||||
def _eval_is_positive(self):
|
||||
return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined]
|
||||
|
||||
@classmethod
|
||||
def eval(cls, c, p, q):
|
||||
if c == sympy.true:
|
||||
|
|
@ -197,28 +226,27 @@ class Where(sympy.Function):
|
|||
return q
|
||||
|
||||
|
||||
class Mod(sympy.Function):
|
||||
"""
|
||||
We maintain this so that we avoid SymPy correctness issues, such as:
|
||||
https://github.com/sympy/sympy/issues/25146
|
||||
"""
|
||||
|
||||
# Python-style modulus: take sign from RHS
|
||||
class PythonMod(sympy.Function):
|
||||
nargs = (2,)
|
||||
|
||||
is_integer = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, p, q):
|
||||
# This was adapted from: sympy/core/mod.py
|
||||
# python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
|
||||
# Triggered by sympy.solvers.inequalities.reduce_inequalities
|
||||
# assert p.is_integer, p
|
||||
# assert q.is_integer, q
|
||||
|
||||
if q.is_zero:
|
||||
raise ZeroDivisionError("Modulo by zero")
|
||||
# If either of them is NaN or infinite.
|
||||
if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False:
|
||||
return S.NaN
|
||||
|
||||
# Three cases:
|
||||
# 1. p == 0
|
||||
# 2. p is either q or -q
|
||||
# 3. p is integer and q == 1
|
||||
if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1):
|
||||
if p is S.Zero or p in (q, -q) or q == 1:
|
||||
return S.Zero
|
||||
|
||||
# Evaluate if they are both literals.
|
||||
|
|
@ -247,10 +275,7 @@ class Mod(sympy.Function):
|
|||
if sympy.Mod(p, q) == 0:
|
||||
return S.Zero
|
||||
|
||||
def _eval_is_integer(self):
|
||||
p, q = self.args
|
||||
return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)]) # type: ignore[attr-defined]
|
||||
|
||||
# NB: args[1] for PythonMod
|
||||
def _eval_is_nonnegative(self):
|
||||
return True if self.args[1].is_positive else None # type: ignore[attr-defined]
|
||||
|
||||
|
|
@ -258,6 +283,58 @@ class Mod(sympy.Function):
|
|||
return True if self.args[1].is_negative else None # type: ignore[attr-defined]
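Concretely, this is the "sign from RHS" behavior PythonMod models, shown next to C-style remainder (math.fmod is used here purely for contrast):

    import math

    # Python % takes its sign from the divisor (right-hand side) ...
    print(-7 % 3, 7 % -3)                      # 2 -2
    # ... whereas C-style remainder takes it from the dividend.
    print(math.fmod(-7, 3), math.fmod(7, -3))  # -1.0 1.0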
|
||||
|
||||
|
||||
# Generic modulus: only defined on non-negative arguments
|
||||
class Mod(sympy.Function):
|
||||
nargs = (2,)
|
||||
|
||||
is_integer = True
|
||||
is_nonnegative = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, p, q):
|
||||
# This was adapted from: sympy/core/mod.py
|
||||
|
||||
# Triggered by
|
||||
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
|
||||
# assert p.is_integer, p
|
||||
# assert q.is_integer, q
|
||||
|
||||
if q.is_zero:
|
||||
raise ZeroDivisionError("Modulo by zero")
|
||||
|
||||
# Three cases:
|
||||
# 1. p == 0
|
||||
# 2. p is either q or -q
|
||||
# 3. p is integer and q == 1
|
||||
if p is S.Zero or p in (q, -q) or q == 1:
|
||||
return S.Zero
|
||||
|
||||
# Evaluate if they are both literals.
|
||||
if q.is_Number and p.is_Number:
|
||||
assert p >= 0, p
|
||||
assert q >= 1, q
|
||||
return p % q
|
||||
|
||||
# If q == 2, it's a matter of whether p is odd or even.
|
||||
if q.is_Number and q == 2:
|
||||
if p.is_even:
|
||||
return S.Zero
|
||||
if p.is_odd:
|
||||
return S.One
|
||||
|
||||
# If p is a multiple of q.
|
||||
r = p / q
|
||||
if r.is_integer:
|
||||
return S.Zero
|
||||
|
||||
# If p < q and its ratio is positive, then:
|
||||
# - floor(p / q) = 0
|
||||
# - p % q = p - floor(p / q) * q = p
|
||||
less = p < q
|
||||
if less.is_Boolean and bool(less) and r.is_positive:
|
||||
return p
|
||||
|
||||
|
||||
class CleanDiv(FloorDiv):
|
||||
"""
|
||||
Div where we can assume no rounding.
|
||||
|
|
@ -267,6 +344,36 @@ class CleanDiv(FloorDiv):
|
|||
pass
|
||||
|
||||
|
||||
# Don't use sympy ceiling/floor as they will attempt simplifications involving
|
||||
# frac
|
||||
class CeilToInt(sympy.Function):
|
||||
is_integer = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, number):
|
||||
# assert number.is_integer is not True, number
|
||||
if number == sympy.oo:
|
||||
return sympy.Integer(sys.maxsize - 1)
|
||||
if number == -sympy.oo:
|
||||
return sympy.Integer(-sys.maxsize - 1)
|
||||
if isinstance(number, sympy.Number):
|
||||
return sympy.Integer(math.ceil(float(number)))
|
||||
|
||||
|
||||
class FloorToInt(sympy.Function):
|
||||
is_integer = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, number):
|
||||
# assert number.is_integer is not True, number
|
||||
if number == sympy.oo:
|
||||
return sympy.Integer(sys.maxsize - 1)
|
||||
if number == -sympy.oo:
|
||||
return sympy.Integer(-sys.maxsize - 1)
|
||||
if isinstance(number, sympy.Number):
|
||||
return sympy.Integer(math.floor(float(number)))
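Based only on the eval bodies above, these behave roughly as follows (a sketch, assuming the classes are imported from torch.utils._sympy.functions as elsewhere in this PR; symbolic inputs simply stay unevaluated):

    import sys
    import sympy
    from torch.utils._sympy.functions import CeilToInt, FloorToInt

    print(CeilToInt(sympy.Float(2.1)))             # 3
    print(FloorToInt(sympy.Float(2.9)))            # 2
    # Infinities saturate to the sys.maxsize-based sentinel instead of staying symbolic.
    print(CeilToInt(sympy.oo) == sys.maxsize - 1)  # True
    x = sympy.Symbol("x")
    print(FloorToInt(x))                           # FloorToInt(x), unevaluated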
|
||||
|
||||
|
||||
class CeilDiv(sympy.Function):
|
||||
"""
|
||||
Div used in indexing that rounds up.
|
||||
|
|
@ -275,6 +382,8 @@ class CeilDiv(sympy.Function):
|
|||
is_integer = True
|
||||
|
||||
def __new__(cls, base, divisor):
|
||||
base = sympy.sympify(base)
|
||||
divisor = sympy.sympify(divisor)
|
||||
if sympy.gcd(base, divisor) == divisor:
|
||||
return CleanDiv(base, divisor)
|
||||
else:
|
||||
|
|
@ -282,6 +391,8 @@ class CeilDiv(sympy.Function):
|
|||
|
||||
|
||||
class LShift(sympy.Function):
|
||||
is_integer = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, shift):
|
||||
if shift < 0:
|
||||
|
|
@ -290,6 +401,8 @@ class LShift(sympy.Function):
|
|||
|
||||
|
||||
class RShift(sympy.Function):
|
||||
is_integer = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, shift):
|
||||
if shift < 0:
|
||||
|
|
@ -297,28 +410,107 @@ class RShift(sympy.Function):
|
|||
return base // 2**shift
|
||||
|
||||
|
||||
# Overloaded to be compatible with regular Python.
|
||||
# https://github.com/pytorch/pytorch/issues/90900
|
||||
class Pow(sympy.Function):
|
||||
def safe_pow(base, exp):
|
||||
sign = 1
|
||||
if base < 0:
|
||||
base = -base
|
||||
sign = 1 if exp % 2 == 0 else -1
|
||||
return sign * _safe_pow(base, exp)
|
||||
|
||||
|
||||
def _safe_pow(base, exponent):
|
||||
if exponent < 0:
|
||||
raise ValueError("Exponent must be non-negative.")
|
||||
|
||||
if exponent == 0:
|
||||
return 1
|
||||
|
||||
half_exp = safe_pow(base, exponent // 2)
|
||||
if half_exp > sys.maxsize - 1:
|
||||
return sys.maxsize - 1
|
||||
|
||||
result = half_exp * half_exp
|
||||
if result > sys.maxsize - 1:
|
||||
return sys.maxsize - 1
|
||||
|
||||
if exponent % 2 == 1:
|
||||
result *= base
|
||||
if result > sys.maxsize - 1:
|
||||
return sys.maxsize - 1
|
||||
|
||||
return result
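In other words, safe_pow is integer exponentiation that saturates at sys.maxsize - 1 rather than producing arbitrarily large exact values; a small sketch of the intended behavior:

    import sys
    from torch.utils._sympy.functions import safe_pow

    print(safe_pow(3, 4))                        # 81
    print(safe_pow(-2, 3))                       # -8, sign handled by the outer wrapper
    print(safe_pow(2, 1000) == sys.maxsize - 1)  # True: clamped instead of exact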
|
||||
|
||||
|
||||
class PowByNatural(sympy.Function):
|
||||
is_integer = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, exp):
|
||||
if exp.is_zero:
|
||||
return sympy.Integer(1)
|
||||
elif base.is_zero and exp < 0:
|
||||
raise ZeroDivisionError(f"{base} cannot be raised to a negative power")
|
||||
else:
|
||||
return base**exp
|
||||
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
|
||||
return sympy.Integer(safe_pow(base, exp))
|
||||
if isinstance(exp, sympy.Integer):
|
||||
# Translate power into iterated multiplication
|
||||
r = sympy.Integer(1)
|
||||
for _ in range(int(exp)):
|
||||
r *= base
|
||||
return r
|
||||
# NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
|
||||
# is a natural number if we do
|
||||
|
||||
|
||||
# base is assumed to be nonnegative, thereby preventing complex numbers from
# occurring
|
||||
class FloatPow(sympy.Function):
|
||||
is_integer = False
|
||||
is_real = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, exp):
|
||||
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
|
||||
return sympy.Float(float(base) ** float(exp))
|
||||
# NB: do not do any nontrivial reasoning
|
||||
|
||||
|
||||
# Overloaded to be compatible with regular Python.
|
||||
# https://github.com/pytorch/pytorch/issues/90900
|
||||
class TrueDiv(sympy.Function):
|
||||
#
|
||||
# In particular, sympy division is willing to simplify x/x == 1
|
||||
# where 1 is an integer, but this must be a float if x was float.
|
||||
class FloatTrueDiv(sympy.Function):
|
||||
is_integer = False
|
||||
is_real = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, divisor):
|
||||
# assert base.is_integer is not True, base
|
||||
# assert divisor.is_integer is not True, divisor
|
||||
|
||||
if divisor.is_zero:
|
||||
raise ZeroDivisionError("division by zero")
|
||||
|
||||
if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
|
||||
return sympy.Float(float(base) / float(divisor))
|
||||
|
||||
|
||||
# Overloaded to be compatible with regular Python. We distinguish this from
|
||||
# FloatTrueDiv, because the code generation has to be different for this case:
|
||||
# Python has a fancy algorithm for integer true division that isn't just
|
||||
# "promote both arguments to float and use float division", so you need to
|
||||
# codegen it differently. While technically you can work it out from the
|
||||
# types of the input, this is often inconvenient to do in Inductor codegen,
|
||||
# so just have a different operator
|
||||
# NB: Right now, Inductor codegen doesn't implement this correctly lol
|
||||
class IntTrueDiv(sympy.Function):
|
||||
is_integer = False
|
||||
is_real = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, divisor):
|
||||
if divisor.is_zero:
|
||||
raise ZeroDivisionError("division by zero")
|
||||
else:
|
||||
return base / divisor
|
||||
|
||||
if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
|
||||
return sympy.Float(int(base) / int(divisor))
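The precision point is easiest to see with plain Python ints: CPython divides int / int with enough care to return the correctly rounded double, while converting both operands to float first can land on a different double once the numerator is near or above 2**53. A small illustration (plain Python, not the sympy class itself):

    a, b = 2**53 + 3, 7

    exact = a / b                    # ints divided directly: correctly rounded
    via_float = float(a) / float(b)  # float(a) already rounds 2**53 + 3 up to 2**53 + 4

    print(exact == 1286742750677285)  # True: 7 divides 2**53 + 3 exactly
    print(via_float == exact)         # False: off by one ulp (0.25 at this magnitude)

IntTrueDiv is the hook that lets codegen eventually reproduce the correctly rounded behavior instead of the float-first one.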
|
||||
|
||||
|
||||
# TODO: As an indicator, this != 0 implies == 1 (and vice versa).
|
||||
|
|
@ -353,45 +545,85 @@ class IsNonOverlappingAndDenseIndicator(sympy.Function):
|
|||
return None
|
||||
|
||||
|
||||
class Trunc(sympy.Function):
|
||||
# NB: this is inconsistent with math.trunc in Python
|
||||
class TruncToFloat(sympy.Function):
|
||||
is_integer = False
|
||||
is_real = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, number):
|
||||
# assert number.is_integer is not True, number
|
||||
if isinstance(number, sympy.Number):
|
||||
# NB: It is safe to use truncation to integer, which is what
|
||||
# math.trunc does, as Python integers are arbitrary precision and
|
||||
# so we are guaranteed not to lose precision when we do this
|
||||
return sympy.Float(math.trunc(float(number)))
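For reference, truncation rounds toward zero while floor rounds toward -inf; the distinction only shows up for negative inputs:

    import math

    print(math.trunc(2.7), math.trunc(-2.7))  # 2 -2
    print(math.floor(2.7), math.floor(-2.7))  # 2 -3
    # TruncToFloat additionally keeps the result a float, i.e. 2.0 / -2.0 in sympy terms.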
|
||||
|
||||
|
||||
class TruncToInt(sympy.Function):
|
||||
is_integer = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, number):
|
||||
if number.is_integer:
|
||||
return number
|
||||
elif isinstance(number, sympy.Number):
|
||||
# assert number.is_integer is not True, number
|
||||
if number == sympy.oo:
|
||||
return sympy.Integer(sys.maxsize - 1)
|
||||
if number == -sympy.oo:
|
||||
return sympy.Integer(-sys.maxsize - 1)
|
||||
if isinstance(number, sympy.Number):
|
||||
return sympy.Integer(math.trunc(float(number)))
|
||||
|
||||
|
||||
class Round(sympy.Function):
|
||||
# This is float -> int
|
||||
class RoundToInt(sympy.Function):
|
||||
is_integer = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, number):
|
||||
if number.is_integer:
|
||||
return number
|
||||
elif isinstance(number, sympy.Number):
|
||||
return sympy.Integer(round(float(number)))
|
||||
# assert number.is_integer is not True, number
|
||||
|
||||
def __int__(self):
|
||||
# This will only ever be called when computing size hints. At that point, self.args[0] should be a number and
|
||||
# no longer an expression. If it were, the float call would fail and the caller would handle this further.
|
||||
return round(float(self.args[0])) # type: ignore[arg-type]
|
||||
if isinstance(number, sympy.Float):
|
||||
return sympy.Integer(round(float(number), 0))
|
||||
|
||||
|
||||
# To get float -> int, Python style round semantics.
|
||||
#
|
||||
# x = PyFloat_AsDouble(self);
|
||||
# if (o_ndigits == Py_None) {
|
||||
# /* single-argument round or with None ndigits:
|
||||
# * round to nearest integer */
|
||||
# rounded = round(x);
|
||||
# if (fabs(x-rounded) == 0.5)
|
||||
# /* halfway case: round to even */
|
||||
# rounded = 2.0*round(x/2.0);
|
||||
# return PyLong_FromDouble(rounded);
|
||||
# }
|
||||
|
||||
|
||||
# NB: Like Round, this only ever returns floats. ndigits cannot be None
|
||||
class RoundDecimal(sympy.Function):
|
||||
is_integer = False
|
||||
is_real = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, number, ndigits):
|
||||
if number.is_integer and ndigits >= 0:
|
||||
# assert number.is_integer is not True, number
|
||||
|
||||
if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer):
|
||||
return sympy.Float(round(float(number), int(ndigits)))
|
||||
|
||||
|
||||
class ToFloat(sympy.Function):
|
||||
is_integer = False
|
||||
is_real = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, number):
|
||||
if number in [sympy.oo, -sympy.oo]:
|
||||
return number
|
||||
elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
|
||||
value_type, output_type = (
|
||||
(int, sympy.Integer)
|
||||
if isinstance(number, sympy.Integer)
|
||||
else (float, sympy.Float)
|
||||
)
|
||||
return output_type(round(value_type(number), int(ndigits)))
|
||||
|
||||
if isinstance(number, sympy.Integer):
|
||||
return sympy.Float(int(number))
|
||||
|
||||
|
||||
def make_opaque_unary_fn(name):
|
||||
@ -15,16 +15,23 @@ from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
|
|||
|
||||
import torch
|
||||
from .functions import (
|
||||
CeilToInt,
|
||||
CleanDiv,
|
||||
FloatPow,
|
||||
FloatTrueDiv,
|
||||
FloorDiv,
|
||||
FloorToInt,
|
||||
IntTrueDiv,
|
||||
IsNonOverlappingAndDenseIndicator,
|
||||
Mod,
|
||||
ModularIndexing,
|
||||
Pow,
|
||||
Round,
|
||||
PowByNatural,
|
||||
PythonMod,
|
||||
RoundDecimal,
|
||||
TrueDiv,
|
||||
Trunc,
|
||||
RoundToInt,
|
||||
ToFloat,
|
||||
TruncToFloat,
|
||||
TruncToInt,
|
||||
Where,
|
||||
)
|
||||
|
||||
|
|
@ -49,30 +56,39 @@ def handlers():
|
|||
sympy.Le: "le",
|
||||
sympy.Ge: "ge",
|
||||
sympy.Not: "not_",
|
||||
TrueDiv: "truediv",
|
||||
IntTrueDiv: "int_truediv",
|
||||
FloatTrueDiv: "truediv",
|
||||
FloorDiv: "floordiv",
|
||||
CleanDiv: "div",
|
||||
Trunc: "trunc",
|
||||
CleanDiv: "floordiv", # TODO: hmm?
|
||||
TruncToFloat: "trunc",
|
||||
Where: "where",
|
||||
sympy.Add: "add",
|
||||
sympy.Mul: "mul",
|
||||
Pow: "pow",
|
||||
sympy.Pow: "pow",
|
||||
FloatPow: "pow",
|
||||
PowByNatural: "pow_by_natural",
|
||||
# sympy simplifies x * x into Pow(x, 2), so we need to handle this.
|
||||
# Do NOT use builtin Pow for floats
|
||||
# TODO: There is a hazard here, if we have float * float it will
|
||||
# also get turned into Pow(float, 2) but we don't want this because
|
||||
# pow_by_natural is assumed to only be integers. Probably the fix is
|
||||
# to add a FloatMul to impede this optimization
|
||||
sympy.Pow: "pow_by_natural",
|
||||
Mod: "mod",
|
||||
PythonMod: "mod", # TODO: this is wrong
|
||||
# TODO: Inductor can generate these, but it's ill-specified which
|
||||
# semantics were intended here. Needs to be cleaned up along with
|
||||
# FloorDiv in a bigger cleanup
|
||||
sympy.Mod: "mod",
|
||||
sympy.Abs: "abs",
|
||||
sympy.log: "log",
|
||||
sympy.exp: "exp",
|
||||
sympy.floor: "floor",
|
||||
sympy.ceiling: "ceil",
|
||||
sympy.Min: "minimum",
|
||||
sympy.Max: "maximum",
|
||||
ModularIndexing: "modular_indexing",
|
||||
sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
|
||||
sympy.Piecewise: "piecewise",
|
||||
IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
|
||||
Round: "round",
|
||||
RoundDecimal: "round",
|
||||
RoundDecimal: "round_decimal",
|
||||
}
|
||||
for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
|
||||
HANDLERS[getattr(sympy, name)] = name
|
||||
|
|
@ -84,7 +100,11 @@ ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}
|
|||
|
||||
|
||||
def sympy_interp(
|
||||
analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean]
|
||||
analysis,
|
||||
env: Dict[sympy.Symbol, Any],
|
||||
expr: Union[sympy.Expr, SympyBoolean],
|
||||
*,
|
||||
index_dtype=torch.int64,
|
||||
):
|
||||
# Handle base cases
|
||||
dtype = None
|
||||
|
|
@ -105,9 +125,32 @@ def sympy_interp(
|
|||
expr.args[1], sympy.core.numbers.Half
|
||||
):
|
||||
return analysis.sqrt(sympy_interp(analysis, env, expr.args[0]))
|
||||
if isinstance(expr, ToFloat):
|
||||
return analysis.to_dtype(
|
||||
sympy_interp(analysis, env, expr.args[0]), torch.float64
|
||||
)
|
||||
|
||||
# Recursive case
|
||||
args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type]
|
||||
|
||||
# These handlers are special because they take an extra dtype argument
|
||||
# specifying what they should convert to, and we need to appropriately set
|
||||
# this up when we convert from Sympy. A reasonable default when you
|
||||
# are translating is to conservatively do int64, and then narrow these
|
||||
# arguments later when you discover you can narrow the index range. But
|
||||
# if you already know that 32-bit indexing is OK, you can directly do the
|
||||
# sympy translation with index_dtype=torch.int32
|
||||
INDEX_DTYPE_HANDLERS = {
|
||||
TruncToInt: "trunc_to_int",
|
||||
sympy.floor: "floor_to_int",
|
||||
sympy.ceiling: "ceil_to_int",
|
||||
FloorToInt: "floor_to_int",
|
||||
CeilToInt: "ceil_to_int",
|
||||
RoundToInt: "round_to_int",
|
||||
}
|
||||
if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
|
||||
return getattr(analysis, handler_name)(*args, index_dtype)
|
||||
|
||||
if hasattr(expr.func, "_torch_handler_name"):
|
||||
handler_name = expr.func._torch_handler_name
|
||||
else:
|
||||
@ -1,12 +1,25 @@
|
|||
import math
|
||||
|
||||
import operator
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch.utils._sympy.functions import (
|
||||
_keep_float,
|
||||
FloatPow,
|
||||
FloatTrueDiv,
|
||||
FloorDiv,
|
||||
IntTrueDiv,
|
||||
Mod,
|
||||
OpaqueUnaryFn_exp,
|
||||
OpaqueUnaryFn_log,
|
||||
OpaqueUnaryFn_sqrt,
|
||||
PowByNatural,
|
||||
RoundDecimal,
|
||||
RoundToInt,
|
||||
ToFloat,
|
||||
TruncToInt,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -62,20 +75,41 @@ class ReferenceAnalysis:
|
|||
|
||||
@staticmethod
|
||||
def reciprocal(x):
|
||||
return 1 / x
|
||||
return FloatTrueDiv(1.0, x)
|
||||
|
||||
@staticmethod
|
||||
def square(x):
|
||||
return x * x
|
||||
return PowByNatural(x, 2)
|
||||
|
||||
@staticmethod
|
||||
def trunc_to_int(x, dtype):
|
||||
return TruncToInt(x)
|
||||
|
||||
@staticmethod
|
||||
def ceil_to_int(x, dtype):
|
||||
return sympy.ceiling(x)
|
||||
|
||||
@staticmethod
|
||||
def floor_to_int(x, dtype):
|
||||
return sympy.floor(x)
|
||||
|
||||
@staticmethod
|
||||
def floor(x):
|
||||
return _keep_float(sympy.floor)(x)
|
||||
|
||||
@staticmethod
|
||||
def ceil(x):
|
||||
return _keep_float(sympy.ceiling)(x)
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(x, dtype):
|
||||
if dtype == torch.float64:
|
||||
return ToFloat(x)
|
||||
raise NotImplementedError(f"to_dtype {dtype} NYI")
|
||||
|
||||
@staticmethod
|
||||
def mod(x, y):
|
||||
ret = abs(x) % abs(y)
|
||||
# without check:
|
||||
# tracing will fail trying to go through control-flow if x is Proxy()
|
||||
if isinstance(x, (int, sympy.Number)) and x < 0:
|
||||
ret *= -1
|
||||
return ret
|
||||
return Mod(x, y)
|
||||
|
||||
@staticmethod
|
||||
def abs(x):
|
||||
|
|
@ -87,37 +121,31 @@ class ReferenceAnalysis:
|
|||
|
||||
@staticmethod
|
||||
def truediv(a, b):
|
||||
return a / b
|
||||
return FloatTrueDiv(a, b)
|
||||
|
||||
@staticmethod
|
||||
def div(a, b):
|
||||
return ReferenceAnalysis.truediv(a, b)
|
||||
def int_truediv(a, b):
|
||||
return IntTrueDiv(a, b)
|
||||
|
||||
@staticmethod
|
||||
def floordiv(a, b):
|
||||
if b == 0:
|
||||
return sympy.nan if a == 0 else sympy.zoo
|
||||
return a // b
|
||||
return FloorDiv(a, b)
|
||||
|
||||
@staticmethod
|
||||
def truncdiv(a, b):
|
||||
result = a / b
|
||||
if result.is_finite:
|
||||
result = sympy.Integer(result)
|
||||
|
||||
return result
|
||||
raise NotImplementedError("TODO: truncdiv")
|
||||
|
||||
@staticmethod
|
||||
def add(a, b):
|
||||
return a + b
|
||||
return _keep_float(operator.add)(a, b)
|
||||
|
||||
@staticmethod
|
||||
def mul(a, b):
|
||||
return a * b
|
||||
return _keep_float(operator.mul)(a, b)
|
||||
|
||||
@staticmethod
|
||||
def sub(a, b):
|
||||
return a - b
|
||||
return _keep_float(operator.sub)(a, b)
|
||||
|
||||
@staticmethod
|
||||
def exp(x):
|
||||
|
|
@ -133,39 +161,27 @@ class ReferenceAnalysis:
|
|||
|
||||
@staticmethod
|
||||
def pow(a, b):
|
||||
return a**b
|
||||
return _keep_float(FloatPow)(a, b)
|
||||
|
||||
@staticmethod
|
||||
def pow_by_natural(a, b):
|
||||
return PowByNatural(a, b)
|
||||
|
||||
@staticmethod
|
||||
def minimum(a, b):
|
||||
# Poorman's version of upcasting in Sympy
|
||||
# This won't do for sympy.Expr as the casting does nothing for those
|
||||
if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
|
||||
result_type = sympy.Float
|
||||
else:
|
||||
assert a.is_Integer
|
||||
assert b.is_Integer
|
||||
result_type = sympy.Integer
|
||||
return sympy.Min(result_type(a), result_type(b))
|
||||
return sympy.Min(a, b)
|
||||
|
||||
@staticmethod
|
||||
def maximum(a, b):
|
||||
# Poorman's version of upcasting in Sympy
|
||||
# This won't do for sympy.Expr as the casting does nothing for those
|
||||
if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
|
||||
result_type = sympy.Float
|
||||
else:
|
||||
assert a.is_Integer
|
||||
assert b.is_Integer
|
||||
result_type = sympy.Integer
|
||||
return sympy.Max(result_type(a), result_type(b))
|
||||
return sympy.Max(a, b)
|
||||
|
||||
@staticmethod
|
||||
def floor(x):
|
||||
return sympy.floor(x)
|
||||
def round_to_int(a, dtype):
|
||||
return RoundToInt(a)
|
||||
|
||||
@staticmethod
|
||||
def ceil(x):
|
||||
return sympy.ceiling(x)
|
||||
def round_decimal(a, b):
|
||||
return RoundDecimal(a, b)
|
||||
|
||||
|
||||
# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain
|
||||
|
|
@ -191,10 +207,20 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
|
|||
def floordiv(a, b):
|
||||
return a // b
|
||||
|
||||
@staticmethod
|
||||
def mod(x, y):
|
||||
return x % y
|
||||
|
||||
@staticmethod
|
||||
def truncdiv(a, b):
|
||||
return a / b
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(x, dtype):
|
||||
if dtype == torch.float64:
|
||||
return float(x)
|
||||
raise NotImplementedError(f"to_dtype {dtype} NYI")
|
||||
|
||||
@staticmethod
|
||||
def exp(x):
|
||||
raise AssertionError("exp is not valid shape sympy expr")
|
||||
|
|
@ -216,9 +242,40 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
|
|||
return torch.sym_max(a, b)
|
||||
|
||||
@staticmethod
|
||||
def floor(x):
|
||||
def floor_to_int(x, dtype):
|
||||
return math.floor(x)
|
||||
|
||||
@staticmethod
|
||||
def ceil(x):
|
||||
def ceil_to_int(x, dtype):
|
||||
return math.ceil(x)
|
||||
|
||||
@staticmethod
|
||||
def floor(x):
|
||||
return float(math.floor(x))
|
||||
|
||||
@staticmethod
|
||||
def ceil(x):
|
||||
return float(math.ceil(x))
|
||||
|
||||
@staticmethod
|
||||
def truediv(a, b):
|
||||
return a / b
|
||||
|
||||
@staticmethod
|
||||
def pow(a, b):
|
||||
return a**b
|
||||
|
||||
@staticmethod
|
||||
def pow_by_natural(a, b):
|
||||
# Pray that safe_pow is not needed here lol. In particular, this
|
||||
# never participates in VR low/high ranges, so overflow should be
|
||||
# unlikely
|
||||
return a**b
|
||||
|
||||
@staticmethod
|
||||
def round_to_int(a, dtype):
|
||||
return round(a)
|
||||
|
||||
@staticmethod
|
||||
def round_decimal(a, b):
|
||||
return round(a, ndigits=b)
|
||||
@ -88,6 +88,7 @@ def try_solve(
|
|||
|
||||
# Return if we were able to isolate 'thing' on the left-hand side.
|
||||
if isinstance(e, sympy.Rel) and e.lhs == thing:
|
||||
log.debug("solved: %s ---> %s", expr, e)
|
||||
return e, e.rhs
|
||||
|
||||
return None
|
||||
@ -5,6 +5,7 @@ import itertools
|
|||
import logging
|
||||
import math
|
||||
import operator
|
||||
import sys
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
|
|
@ -25,17 +26,20 @@ import torch
|
|||
|
||||
from torch._prims_common import dtype_to_type
|
||||
from .functions import (
|
||||
OpaqueUnaryFn_acos,
|
||||
OpaqueUnaryFn_asinh,
|
||||
OpaqueUnaryFn_atan,
|
||||
OpaqueUnaryFn_cosh,
|
||||
_keep_float,
|
||||
FloatTrueDiv,
|
||||
FloorDiv,
|
||||
IntTrueDiv,
|
||||
OpaqueUnaryFn_exp,
|
||||
OpaqueUnaryFn_log,
|
||||
OpaqueUnaryFn_sinh,
|
||||
OpaqueUnaryFn_sqrt,
|
||||
OpaqueUnaryFn_tanh,
|
||||
Round,
|
||||
PowByNatural,
|
||||
RoundDecimal,
|
||||
RoundToInt,
|
||||
safe_pow,
|
||||
ToFloat,
|
||||
TruncToFloat,
|
||||
TruncToInt,
|
||||
)
|
||||
from .interp import sympy_interp
|
||||
|
||||
|
|
@ -120,6 +124,8 @@ class ValueRanges(Generic[_T]):
|
|||
lower: _T
|
||||
upper: _T
|
||||
is_bool: bool
|
||||
is_int: bool
|
||||
is_float: bool
|
||||
|
||||
@overload
|
||||
def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None:
|
||||
|
|
@ -142,8 +148,39 @@ class ValueRanges(Generic[_T]):
|
|||
# Because this is a frozen class
|
||||
object.__setattr__(self, "lower", lower)
|
||||
object.__setattr__(self, "upper", upper)
|
||||
# Unlike bool/int in Python, we don't report bools are ints
|
||||
object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean))
|
||||
assert isinstance(upper, SympyBoolean) == self.is_bool
|
||||
if self.is_bool:
|
||||
assert isinstance(upper, SympyBoolean), (lower, upper)
|
||||
|
||||
# Warning: is_int/is_float is best effort. We do pretty well in
|
||||
# Dynamo, but in Inductor these attributes are often wrong because we
|
||||
# are not very rigorous in dtype analysis. This is also why we need
|
||||
# the flexible analysis for is_int: sometimes a sympy.oo pops in for
|
||||
# an integer bound. I would /like/ for us not to do this, but it's
|
||||
# too hard to push the invariant through right now.
|
||||
|
||||
object.__setattr__(
|
||||
self,
|
||||
"is_int",
|
||||
not self.is_bool
|
||||
and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)),
|
||||
)
|
||||
"""
|
||||
# This assert is just impossible right now, too many sympy bugs
|
||||
if self.is_int:
|
||||
# NB: sympy will sometimes randomly lose the float-ness of zero,
|
||||
# so we also need to account for that in the assertion here.
|
||||
# See also https://github.com/sympy/sympy/issues/26620
|
||||
assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], (
|
||||
lower,
|
||||
upper,
|
||||
)
|
||||
assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper)
|
||||
"""
|
||||
# NB: [-oo, oo] always advertises as float!
|
||||
object.__setattr__(self, "is_float", not self.is_bool and not self.is_int)
|
||||
assert self.is_bool or self.is_int or self.is_float, (lower, upper)
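Roughly how the flags come out for simple numeric bounds (a sketch; as the warning above says, is_int/is_float are best effort, and the commented-out assert is what would eventually reject mixed int/float bounds):

    import sympy
    from torch.utils._sympy.value_ranges import ValueRanges

    print(ValueRanges(sympy.Integer(0), sympy.Integer(10)).is_int)   # True
    print(ValueRanges(sympy.Float(0.0), sympy.Float(1.0)).is_float)  # True
    print(ValueRanges(sympy.Integer(0), sympy.Integer(10)).is_bool)  # False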
|
||||
|
||||
def boolify(self) -> ValueRanges[SympyBoolean]:
|
||||
if vr_is_bool(self):
|
||||
|
|
@ -184,6 +221,8 @@ class ValueRanges(Generic[_T]):
|
|||
if self == ValueRanges.unknown():
|
||||
return other
|
||||
assert self.is_bool == other.is_bool, (self, other)
|
||||
assert self.is_int == other.is_int, (self, other)
|
||||
assert self.is_float == other.is_float, (self, other)
|
||||
if self.is_bool:
|
||||
return ValueRanges(
|
||||
sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)
|
||||
|
|
@ -353,7 +392,12 @@ class SymPyValueRangeAnalysis:
|
|||
# using nan makes subsequent computation throw, and for the purposes of optimization
|
||||
# returning -math.inf - math.inf is equivalent to giving up
|
||||
if isinstance(value, SupportsFloat) and math.isnan(value):
|
||||
return ValueRanges.unknown()
|
||||
if dtype == torch.bool:
|
||||
return ValueRanges.unknown_bool()
|
||||
elif dtype.is_floating_point:
|
||||
return ValueRanges.unknown()
|
||||
else:
|
||||
return ValueRanges(-sys.maxsize - 1, sys.maxsize)
|
||||
|
||||
if is_python:
|
||||
type_ = dtype_to_type(dtype)
|
||||
|
|
@ -369,7 +413,18 @@ class SymPyValueRangeAnalysis:
|
|||
# dtype is intXX
|
||||
assert value.is_integer
|
||||
|
||||
return ValueRanges.wrap(value)
|
||||
r = ValueRanges.wrap(value)
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(a, dtype, src_dtype=None):
|
||||
if dtype == torch.float64:
|
||||
return ValueRanges.increasing_map(a, ToFloat)
|
||||
return ValueRanges.unknown()
|
||||
|
||||
@staticmethod
|
||||
def trunc_to_int(a, dtype):
|
||||
return ValueRanges.increasing_map(a, TruncToInt)
|
||||
|
||||
@staticmethod
|
||||
def not_(a):
|
||||
|
|
@ -428,7 +483,9 @@ class SymPyValueRangeAnalysis:
|
|||
|
||||
@staticmethod
|
||||
def add(a, b):
|
||||
return ValueRanges.coordinatewise_increasing_map(a, b, operator.add)
|
||||
return ValueRanges.coordinatewise_increasing_map(
|
||||
a, b, _keep_float(operator.add)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def mul(cls, a, b):
|
||||
|
|
@ -448,11 +505,20 @@ class SymPyValueRangeAnalysis:
|
|||
else:
|
||||
return a * b
|
||||
|
||||
return ValueRanges.coordinatewise_monotone_map(a, b, safe_mul)
|
||||
return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul))
|
||||
|
||||
@classmethod
|
||||
def div(cls, a, b):
|
||||
return cls.truediv(a, b)
|
||||
@staticmethod
|
||||
def int_truediv(a, b):
|
||||
a = ValueRanges.wrap(a)
|
||||
b = ValueRanges.wrap(b)
|
||||
if 0 in b or (
|
||||
(-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
|
||||
):
|
||||
return ValueRanges.unknown()
|
||||
else:
|
||||
return ValueRanges.coordinatewise_monotone_map(
|
||||
a, b, _keep_float(IntTrueDiv)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def truediv(a, b):
|
||||
|
|
@ -463,18 +529,22 @@ class SymPyValueRangeAnalysis:
|
|||
):
|
||||
return ValueRanges.unknown()
|
||||
else:
|
||||
return ValueRanges.coordinatewise_monotone_map(a, b, operator.truediv)
|
||||
return ValueRanges.coordinatewise_monotone_map(
|
||||
a, b, _keep_float(FloatTrueDiv)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def floordiv(a, b):
|
||||
a = ValueRanges.wrap(a)
|
||||
b = ValueRanges.wrap(b)
|
||||
if 0 in b or (
|
||||
(-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
|
||||
# TODO: make this more precise
|
||||
(-sympy.oo in a or sympy.oo in a)
|
||||
or (-sympy.oo in b or sympy.oo in b)
|
||||
):
|
||||
return ValueRanges.unknown()
|
||||
else:
|
||||
return ValueRanges.coordinatewise_monotone_map(a, b, operator.floordiv)
|
||||
return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv)
|
||||
|
||||
@classmethod
|
||||
def mod(cls, x, y):
|
||||
|
|
@ -523,17 +593,51 @@ class SymPyValueRangeAnalysis:
|
|||
|
||||
@classmethod
|
||||
def is_non_overlapping_and_dense_indicator(cls, *args):
|
||||
return ValueRanges.unknown()
|
||||
return ValueRanges.unknown() # TODO: type here is wrong
|
||||
|
||||
@classmethod
|
||||
def pow_by_natural(cls, a, b):
|
||||
a = ValueRanges.wrap(a)
|
||||
b = ValueRanges.wrap(b)
|
||||
if a.is_singleton() and b.is_singleton():
|
||||
return ValueRanges.wrap(safe_pow(a.lower, b.lower))
|
||||
# NB: Exclude zero, because zero is special
|
||||
elif a.lower >= 1:
|
||||
# We should know that b >= 0 but we may have forgotten this fact due
|
||||
# to replacements, so don't assert it, but DO clamp it to prevent
|
||||
# degenerate problems
|
||||
return ValueRanges.coordinatewise_increasing_map(
|
||||
a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural
|
||||
)
|
||||
elif b.is_singleton():
|
||||
if b.lower % 2 == 0:
|
||||
# x^n where n is even
|
||||
return ValueRanges.convex_min_zero_map(
|
||||
a, lambda x: safe_pow(x, b.lower)
|
||||
)
|
||||
else:
|
||||
# x^n where n is odd
|
||||
return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower))
|
||||
else:
|
||||
# a is potentially negative, and we don't know if the exponent is
|
||||
# even or odd. So just conservatively set the upper and lower
|
||||
# bound based on what the maximum absolute value could be, in both
|
||||
# directions
|
||||
max_base = max(a.upper, -a.lower)
|
||||
return ValueRanges(
|
||||
-(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper)
|
||||
)
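Worked examples of the branches above (roughly; this assumes convex_min_zero_map places the minimum at zero whenever zero lies in the base range):

    # a = [2, 5],  b = [0, 2]:   base >= 1, coordinatewise increasing -> [2**0, 5**2] = [1, 25]
    # a = [-3, 2], b = 3 (odd):  x**3 is increasing                   -> [(-3)**3, 2**3] = [-27, 8]
    # a = [-3, 2], b = 4 (even): x**4 is convex with min at 0         -> [0, max(81, 16)] = [0, 81]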
|
||||
|
||||
@classmethod
|
||||
def pow(cls, a, b):
|
||||
def is_integer(val):
|
||||
return isinstance(val, int) or (
|
||||
hasattr(val, "is_integer") and val.is_integer
|
||||
)
|
||||
return ValueRanges.unknown()
|
||||
|
||||
# We could implement all this, but for floating point pow, is there
|
||||
# really a point?
|
||||
"""
|
||||
a = ValueRanges.wrap(a)
|
||||
b = ValueRanges.wrap(b)
|
||||
|
||||
# Not implemented yet. It's a bit tricky
|
||||
# If you want to implement it, compute the partial derivatives of a ** b
|
||||
# and check the ranges where the function is increasing / decreasing
|
||||
|
|
@ -553,8 +657,7 @@ class SymPyValueRangeAnalysis:
|
|||
if b == 0:
|
||||
if not a.lower.is_finite:
|
||||
return ValueRanges.unknown()
|
||||
type_ = sympy.Float if a.lower.is_real else sympy.Integer
|
||||
return ValueRanges.wrap(type_(1))
|
||||
return ValueRanges.wrap(1.0)
|
||||
|
||||
if b < 0:
|
||||
a = cls.reciprocal(a)
|
||||
|
|
@ -563,21 +666,12 @@ class SymPyValueRangeAnalysis:
|
|||
if a == ValueRanges.unknown():
|
||||
return ValueRanges.unknown()
|
||||
|
||||
# Here b > 0
|
||||
if not is_integer(b):
|
||||
# If the base is positive, then we're good, otherwise nothing's defined
|
||||
if a.lower >= 0:
|
||||
return ValueRanges.increasing_map(a, lambda x: x**b)
|
||||
else:
|
||||
return ValueRanges.unknown()
|
||||
# If the base is positive, then we're good, otherwise nothing's defined
|
||||
if a.lower >= 0:
|
||||
return ValueRanges.increasing_map(a, lambda x: x**b)
|
||||
else:
|
||||
# b > 0 integer
|
||||
if b % 2 == 0:
|
||||
# x^n where n is even
|
||||
return ValueRanges.convex_min_zero_map(a, lambda x: x**b)
|
||||
else:
|
||||
# x^n where n is odd
|
||||
return ValueRanges.increasing_map(a, lambda x: x**b)
|
||||
return ValueRanges.unknown()
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def reciprocal(x):
|
||||
|
|
@ -586,7 +680,7 @@ class SymPyValueRangeAnalysis:
|
|||
if 0 in x:
|
||||
return ValueRanges.unknown()
|
||||
else:
|
||||
return ValueRanges.decreasing_map(x, lambda y: 1 / y)
|
||||
return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y))
|
||||
|
||||
@staticmethod
|
||||
def abs(x):
|
||||
|
|
@ -615,45 +709,64 @@ class SymPyValueRangeAnalysis:
|
|||
def min_or_max(a, b, fn):
|
||||
a = ValueRanges.wrap(a)
|
||||
b = ValueRanges.wrap(b)
|
||||
|
||||
# Performs upcasting first
|
||||
def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr:
|
||||
# Poorman's version of upcasting in Sympy
|
||||
# Inf is not a float...
|
||||
if x.is_Integer and y.is_Integer:
|
||||
result_type = sympy.Integer
|
||||
elif x.is_rational and y.is_rational:
|
||||
result_type = sympy.Rational
|
||||
else:
|
||||
assert x.is_real or not x.is_finite or y.is_real or not y.is_finite
|
||||
result_type = sympy.Float
|
||||
return fn(result_type(x), result_type(y))
|
||||
|
||||
return ValueRanges.coordinatewise_increasing_map(a, b, fn_)
|
||||
return ValueRanges.coordinatewise_increasing_map(a, b, fn)
|
||||
|
||||
@classmethod
|
||||
def floor(cls, x):
|
||||
def floor_to_int(cls, x, dtype):
|
||||
return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor)
|
||||
|
||||
@classmethod
|
||||
def ceil(cls, x):
|
||||
def ceil_to_int(cls, x, dtype):
|
||||
return ValueRanges.increasing_map(
|
||||
x, sympy.functions.elementary.integers.ceiling
|
||||
)
|
||||
|
||||
# I think these implementations are sound. The hazard here is that sympy
|
||||
# will carry out the floor/ceil at too high precision and then something
|
||||
# bad will happen when we convert it to float.
|
||||
#
|
||||
# For truncation, the implementation is clearly sound, because the desired
|
||||
# target float is always exactly representable, since you're just chopping
|
||||
# off bits the mantissa. But what about ceil/floor?
|
||||
#
|
||||
# The important constraint here is that we're not defining floor on
|
||||
# arbitrary real numbers, only representable float numbers. So we can
|
||||
# take advantage of the fact that before we reach the first
|
||||
# unrepresentable integer in floating point space, we have the range of
|
||||
# numbers corresponding to exponent zero: all integers, with no fractional
|
||||
# amounts. floor/ceil is an identity operation in this case. In the
|
||||
# range below here, representable floating point numbers are spaced
|
||||
# exactly 1/2 apart, and notably, both the floor/ceil are defined floating
|
||||
# point numbers. There is no "gap" as you step up to the next exponent.
|
||||
|
||||
@classmethod
|
||||
def round(cls, number, ndigits=None):
|
||||
if ndigits is None:
|
||||
fn = Round
|
||||
else:
|
||||
assert ndigits.is_singleton()
|
||||
ndigits = ndigits.lower
|
||||
# We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind
|
||||
# the second parameter.
|
||||
fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731
|
||||
def floor(cls, x):
|
||||
return ValueRanges.increasing_map(
|
||||
x, _keep_float(sympy.functions.elementary.integers.floor)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def ceil(cls, x):
|
||||
return ValueRanges.increasing_map(
|
||||
x, _keep_float(sympy.functions.elementary.integers.ceiling)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def round_decimal(cls, number, ndigits):
|
||||
if not ndigits.is_singleton():
|
||||
return ValueRanges.unknown()
|
||||
|
||||
ndigits = ndigits.lower
|
||||
# We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind
|
||||
# the second parameter.
|
||||
fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731
|
||||
|
||||
return ValueRanges.increasing_map(number, fn)
|
||||
|
||||
@classmethod
|
||||
def round_to_int(cls, number, dtype):
|
||||
return ValueRanges.increasing_map(number, RoundToInt)
|
||||
|
||||
# It's used in some models on symints
|
||||
@staticmethod
|
||||
def sqrt(x):
|
||||
|
|
@ -708,12 +821,15 @@ class SymPyValueRangeAnalysis:
|
|||
|
||||
@staticmethod
|
||||
def cosh(x):
|
||||
return ValueRanges(0.0, sympy.oo)
|
||||
"""
|
||||
x = ValueRanges.wrap(x)
|
||||
if x.lower > 0:
|
||||
return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh)
|
||||
elif x.upper < 0:
|
||||
return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh)
|
||||
return ValueRanges(0.0, sympy.oo)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def sin(x):
|
||||
|
|
@ -723,7 +839,8 @@ class SymPyValueRangeAnalysis:
|
|||
|
||||
@staticmethod
|
||||
def sinh(x):
|
||||
return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh)
|
||||
# return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh)
|
||||
return ValueRanges(-sympy.oo, sympy.oo)
|
||||
|
||||
@staticmethod
|
||||
def tan(x):
|
||||
|
|
@ -731,32 +848,37 @@ class SymPyValueRangeAnalysis:
|
|||
|
||||
@staticmethod
|
||||
def tanh(x):
|
||||
return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh)
|
||||
# return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh)
|
||||
return ValueRanges(-sympy.oo, sympy.oo)
|
||||
|
||||
@staticmethod
|
||||
def asin(x):
|
||||
return ValueRanges(-sympy.oo, sympy.oo)
|
||||
"""
|
||||
x = ValueRanges.wrap(x)
|
||||
if -1 <= x.lower and x.upper <= 1:
|
||||
return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh)
|
||||
return ValueRanges.unknown()
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def acos(x):
|
||||
return ValueRanges(-sympy.oo, sympy.oo)
|
||||
"""
|
||||
x = ValueRanges.wrap(x)
|
||||
if -1 <= x.lower and x.upper <= 1:
|
||||
return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos)
|
||||
return ValueRanges.unknown()
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def atan(x):
|
||||
return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan)
|
||||
return ValueRanges(-sympy.oo, sympy.oo)
|
||||
# return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan)
|
||||
|
||||
@staticmethod
|
||||
def trunc(x):
|
||||
def trunc(x):
|
||||
return sympy.Integer(x) if x.is_finite else x
|
||||
|
||||
return ValueRanges.increasing_map(x, trunc)
|
||||
return ValueRanges.increasing_map(x, TruncToFloat)
|
||||
|
||||
|
||||
class ValueRangeAnalysis(SymPyValueRangeAnalysis):
|
||||
|
|
@ -791,9 +913,10 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
|
|||
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
|
||||
return ValueRanges.unknown()
|
||||
|
||||
def index_expr(self, index, dtype):
|
||||
@classmethod
|
||||
def index_expr(cls, index, dtype):
|
||||
assert isinstance(index, ValueRanges)
|
||||
return index
|
||||
return cls.to_dtype(index, dtype)
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None):
|
||||
|
|
@ -830,12 +953,15 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
|
|||
|
||||
@staticmethod
|
||||
def square(x):
|
||||
return ValueRanges.convex_min_zero_map(x, lambda y: y * y)
|
||||
return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))
|
||||
|
||||
@staticmethod
|
||||
def neg(x):
|
||||
return ValueRanges.decreasing_map(x, operator.neg)
|
||||
|
||||
# TODO: this is slightly inaccurate because truncdiv operates at integer
|
||||
# precision, but we're going through float truediv which means we can
|
||||
# potentially lose precision on the bounds
|
||||
@classmethod
|
||||
def truncdiv(cls, a, b):
|
||||
x = cls.truediv(a, b)
|
||||
|
|
@ -856,6 +982,7 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
|
|||
def bound_sympy(
|
||||
expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None
|
||||
) -> ValueRanges:
|
||||
log.debug("bound_sympy(%s, %s)", expr, ranges)
|
||||
if isinstance(expr, sympy.Number):
|
||||
return ValueRanges.wrap(expr)
|
||||