mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Introduce int_oo (#127693)
In a previous life, we used sympy.oo to represent the lower/upper bounds of integer ranges. Later, we changed this to be sys.maxsize - 1 for a few reasons: (1) sometimes we do tests on a value being exactly sys.maxsize, and we wanted to avoid a data dependent guard in this case, (2) sympy.oo corresponds to floating point infinity, so you get incorrect types for value ranges with oo, and (3) you can do slightly better reasoning if you assume that input sizes fall within representable 64-bit integer range. After working in the sys.maxsize regime for a bit, I've concluded that this was actually a bad idea. Specifically, the problem is that you end up with sys.maxsize in your upper bound, and then whenever you do any sort of size-increasing computation like size * 2, you end up with 2 * sys.maxsize, and you end up doing a ton of arbitrary precision int computation that is totally unnecessary. A symbolic bound is better. But especially after #126905, we can't go back to using sympy.oo, because that advertises that it's not an integer, and now your ValueRanges is typed incorrectly. So what do we do? We define a new numeric constant `int_oo`, which is like `sympy.oo` but it advertises `is_integer`. **test/test_sympy_utils.py** describes some basic properties of the number, and **torch/utils/_sympy/numbers.py** has the actual implementation. The rest of the changes of the PR are working out the implications of this change. I'll give more commentary as inline comments. Fixes https://github.com/pytorch/pytorch/issues/127396 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/127693 Approved by: https://github.com/lezcano ghstack dependencies: #126905
This commit is contained in:
parent
db2fa7b827
commit
9cab5987bd
|
|
@ -253,7 +253,6 @@ Target Expressions:
|
|||
==> (>= 0 s1)
|
||||
==> (>= 0 s2)
|
||||
==> (>= 0 s3)
|
||||
==> (>= 9223372036854775806 s0)
|
||||
|
||||
Failed Source Expressions:
|
||||
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
|
||||
|
|
@ -287,14 +286,14 @@ Failure occurred while running node:
|
|||
Model:
|
||||
==> L['shape'][0]: 1
|
||||
==> L['shape'][1]: 1
|
||||
==> L['shape'][2]: 2
|
||||
==> L['shape'][2]: 0
|
||||
==> L['x'].size()[0]: 3
|
||||
==> L['x'].storage_offset(): 0
|
||||
==> L['x'].stride()[0]: 1
|
||||
==> s0: 3
|
||||
==> s1: 1
|
||||
==> s2: 1
|
||||
==> s3: 2
|
||||
==> s3: 0
|
||||
|
||||
Assertions:
|
||||
==> (== 0 L['x'].storage_offset())
|
||||
|
|
@ -318,10 +317,6 @@ Target Expressions:
|
|||
==> (== L['shape'][2] s3)
|
||||
==> (== L['x'].size()[0] s0)
|
||||
==> (> s0 0)
|
||||
==> (>= 9223372036854775806 s0)
|
||||
==> (>= 9223372036854775807 s1)
|
||||
==> (>= 9223372036854775807 s2)
|
||||
==> (>= 9223372036854775807 s3)
|
||||
|
||||
Failed Source Expressions:
|
||||
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
|
||||
|
|
|
|||
|
|
@ -3473,7 +3473,6 @@ class GraphModule(torch.nn.Module):
|
|||
]
|
||||
false_guard_code = [
|
||||
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
|
||||
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
|
||||
]
|
||||
test_symbool_guards(
|
||||
f,
|
||||
|
|
|
|||
|
|
@ -9309,7 +9309,7 @@ ShapeEnv not equal: field values don't match:
|
|||
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
|
||||
> Right: {0: 0, 1: 1}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
|
||||
> Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
|
||||
> Right: {}
|
||||
==> var_to_sources: values don't match.
|
||||
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
|
||||
|
|
@ -9343,7 +9343,7 @@ ShapeEnv not equal: field values don't match:
|
|||
> Left: 2
|
||||
> Right: 0
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]}
|
||||
> Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]}
|
||||
> Right: {}
|
||||
""",
|
||||
)
|
||||
|
|
@ -9420,8 +9420,8 @@ ShapeEnv not equal: field values don't match:
|
|||
> Left: {s0: 3}
|
||||
> Right: {}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]}
|
||||
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
|
||||
> Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
|
||||
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
|
@ -9458,8 +9458,8 @@ ShapeEnv not equal: field values don't match:
|
|||
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]}
|
||||
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
|
||||
> Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
|
||||
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
|
|
|||
|
|
@ -201,6 +201,19 @@ class TestDynamismExpression(TestCase):
|
|||
dynamic_shapes={"x": {0: dim_x}},
|
||||
)
|
||||
|
||||
def test_export_slice_maxsize(self):
|
||||
class Slice(torch.nn.Module):
|
||||
def forward(self, *args):
|
||||
return torch.ops.aten.slice.Tensor(*args)
|
||||
|
||||
inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
|
||||
dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
|
||||
torch.export.export(
|
||||
Slice(),
|
||||
inp,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
def test_export_constraints_error(self):
|
||||
class ConflictingConstraints(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
|
@ -5183,7 +5196,7 @@ def forward(self, x, y):
|
|||
}
|
||||
export(f, (inputs,), dynamic_shapes=dynamic_shapes)
|
||||
|
||||
def test_disable_forced_specializations(self):
|
||||
def test_disable_forced_specializations_ok(self):
|
||||
# check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags
|
||||
# both behave correctly, avoiding forced specializations and deferring to runtime.
|
||||
# case 1: modulo guards
|
||||
|
|
|
|||
|
|
@ -633,10 +633,6 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
func, (torch.randn(3, 4),)
|
||||
)
|
||||
|
||||
@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
|
||||
error_message="Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}.",
|
||||
reason="https://github.com/pytorch/pytorch/issues/112622",
|
||||
)
|
||||
def test_operator_with_scalar_output(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
|
|
|
|||
|
|
@ -381,6 +381,17 @@ class TestPySymInt(TestCase):
|
|||
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
|
||||
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
|
||||
|
||||
def test_floordiv_static(self):
|
||||
shape_env = ShapeEnv()
|
||||
s0 = create_symint(shape_env, 8)
|
||||
# This was extracted from
|
||||
# python test/inductor/test_cuda_cpp_wrapper.py -k
|
||||
# DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
|
||||
bool(s0 % 2 == 0)
|
||||
bool(s0 % (s0 // 2) == 0)
|
||||
bool(2 * (s0 // 2) == s0)
|
||||
self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))
|
||||
|
||||
def test_numel(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
||||
|
|
|
|||
|
|
@ -1201,7 +1201,9 @@ def forward(self, x_1):
|
|||
batch_size = 4
|
||||
src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
|
||||
gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
|
||||
self.assertEqual(len(gm.shape_env.guards), 0)
|
||||
# Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
|
||||
# 1 ok)
|
||||
self.assertEqual(len(gm.shape_env.guards), 1)
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
|
||||
def test_cpu_scalar_cuda(self):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# Owner(s): ["oncall: pt2"]
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import sys
|
||||
|
||||
import sympy
|
||||
|
|
@ -19,6 +20,7 @@ from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
|||
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
|
||||
from torch.utils._sympy.interp import sympy_interp
|
||||
from torch.utils._sympy.singleton_int import SingletonInt
|
||||
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
|
||||
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
|
||||
import functools
|
||||
import torch.fx as fx
|
||||
|
|
@ -122,6 +124,74 @@ def generate_range(vals):
|
|||
yield ValueRanges(a1, a2)
|
||||
|
||||
|
||||
class TestNumbers(TestCase):
|
||||
def test_int_infinity(self):
|
||||
self.assertIsInstance(int_oo, IntInfinity)
|
||||
self.assertIsInstance(-int_oo, NegativeIntInfinity)
|
||||
self.assertTrue(int_oo.is_integer)
|
||||
# is tests here are for singleton-ness, don't use it for comparisons
|
||||
# against numbers
|
||||
self.assertIs(int_oo + int_oo, int_oo)
|
||||
self.assertIs(int_oo + 1, int_oo)
|
||||
self.assertIs(int_oo - 1, int_oo)
|
||||
self.assertIs(-int_oo - 1, -int_oo)
|
||||
self.assertIs(-int_oo + 1, -int_oo)
|
||||
self.assertIs(-int_oo + (-int_oo), -int_oo)
|
||||
self.assertIs(-int_oo - int_oo, -int_oo)
|
||||
self.assertIs(1 + int_oo, int_oo)
|
||||
self.assertIs(1 - int_oo, -int_oo)
|
||||
self.assertIs(int_oo * int_oo, int_oo)
|
||||
self.assertIs(2 * int_oo, int_oo)
|
||||
self.assertIs(int_oo * 2, int_oo)
|
||||
self.assertIs(-1 * int_oo, -int_oo)
|
||||
self.assertIs(-int_oo * int_oo, -int_oo)
|
||||
self.assertIs(2 * -int_oo, -int_oo)
|
||||
self.assertIs(-int_oo * 2, -int_oo)
|
||||
self.assertIs(-1 * -int_oo, int_oo)
|
||||
self.assertIs(int_oo / 2, sympy.oo)
|
||||
self.assertIs(-(-int_oo), int_oo) # noqa: B002
|
||||
self.assertIs(abs(int_oo), int_oo)
|
||||
self.assertIs(abs(-int_oo), int_oo)
|
||||
self.assertIs(int_oo ** 2, int_oo)
|
||||
self.assertIs((-int_oo) ** 2, int_oo)
|
||||
self.assertIs((-int_oo) ** 3, -int_oo)
|
||||
self.assertEqual(int_oo ** -1, 0)
|
||||
self.assertEqual((-int_oo) ** -1, 0)
|
||||
self.assertIs(int_oo ** int_oo, int_oo)
|
||||
self.assertTrue(int_oo == int_oo)
|
||||
self.assertFalse(int_oo != int_oo)
|
||||
self.assertTrue(-int_oo == -int_oo)
|
||||
self.assertFalse(int_oo == 2)
|
||||
self.assertTrue(int_oo != 2)
|
||||
self.assertFalse(int_oo == sys.maxsize)
|
||||
self.assertTrue(int_oo >= sys.maxsize)
|
||||
self.assertTrue(int_oo >= 2)
|
||||
self.assertTrue(int_oo >= -int_oo)
|
||||
|
||||
def test_relation(self):
|
||||
self.assertIs(sympy.Add(2, int_oo), int_oo)
|
||||
self.assertFalse(-int_oo > 2)
|
||||
|
||||
def test_lt_self(self):
|
||||
self.assertFalse(int_oo < int_oo)
|
||||
self.assertIs(min(-int_oo, -4), -int_oo)
|
||||
self.assertIs(min(-int_oo, -int_oo), -int_oo)
|
||||
|
||||
def test_float_cast(self):
|
||||
self.assertEqual(float(int_oo), math.inf)
|
||||
self.assertEqual(float(-int_oo), -math.inf)
|
||||
|
||||
def test_mixed_oo_int_oo(self):
|
||||
# Arbitrary choice
|
||||
self.assertTrue(int_oo < sympy.oo)
|
||||
self.assertFalse(int_oo > sympy.oo)
|
||||
self.assertTrue(sympy.oo > int_oo)
|
||||
self.assertFalse(sympy.oo < int_oo)
|
||||
self.assertIs(max(int_oo, sympy.oo), sympy.oo)
|
||||
self.assertTrue(-int_oo > -sympy.oo)
|
||||
self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)
|
||||
|
||||
|
||||
class TestValueRanges(TestCase):
|
||||
@parametrize("fn", UNARY_OPS)
|
||||
@parametrize("dtype", ("int", "float"))
|
||||
|
|
|
|||
|
|
@ -734,6 +734,11 @@ def slice_forward(
|
|||
end: Optional[int] = None,
|
||||
step: int = 1,
|
||||
):
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_size_oblivious,
|
||||
statically_known_true,
|
||||
)
|
||||
|
||||
ndim = self.dim()
|
||||
if ndim == 0:
|
||||
raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
|
||||
|
|
@ -760,7 +765,9 @@ def slice_forward(
|
|||
|
||||
if end_val < start_val:
|
||||
end_val = start_val
|
||||
elif end_val > sizes[dim]:
|
||||
elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
|
||||
end_val > sizes[dim]
|
||||
):
|
||||
end_val = sizes[dim]
|
||||
|
||||
storage_offset = self.storage_offset() + start_val * strides[dim]
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import sympy
|
|||
import torch
|
||||
import torch.fx
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
||||
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
||||
|
||||
|
|
@ -23,9 +24,9 @@ class InputDim(NamedTuple):
|
|||
|
||||
def _convert_to_int(val):
|
||||
# Convert simple sympy Integers into concrete int
|
||||
if val == sympy.oo:
|
||||
if val in (sympy.oo, int_oo):
|
||||
return math.inf
|
||||
if val == -sympy.oo:
|
||||
if val in (-sympy.oo, -int_oo):
|
||||
return -math.inf
|
||||
if isinstance(val, sympy.Integer):
|
||||
return int(val)
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from torch.fx.experimental import symbolic_shapes
|
|||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._pytree import treespec_dumps, treespec_loads
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
|
||||
from .schema import ( # type: ignore[attr-defined]
|
||||
Argument,
|
||||
|
|
@ -321,9 +322,9 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...]
|
|||
|
||||
def _sympy_int_to_int(val: sympy.Expr, adjust: str):
|
||||
# Convert simple sympy Integers into concrete int
|
||||
if val == sympy.oo:
|
||||
if val in (sympy.oo, int_oo):
|
||||
return math.inf
|
||||
if val == -sympy.oo:
|
||||
if val in (-sympy.oo, -int_oo):
|
||||
return -math.inf
|
||||
if isinstance(val, sympy.Integer):
|
||||
return int(val)
|
||||
|
|
@ -346,9 +347,9 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str):
|
|||
def _int_to_sympy_int(val) -> sympy.Expr:
|
||||
# Convert concrete int into simple sympy Integers
|
||||
if val == math.inf:
|
||||
return sympy.oo
|
||||
return int_oo
|
||||
if val == -math.inf:
|
||||
return -sympy.oo
|
||||
return -int_oo
|
||||
return sympy.Integer(val)
|
||||
|
||||
|
||||
|
|
@ -1826,7 +1827,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
self.symbol_name_to_range = {}
|
||||
if symbol_name_to_range:
|
||||
for k, vr in symbol_name_to_range.items():
|
||||
lower = int(vr.lower)
|
||||
lower = vr.lower
|
||||
if vr.upper >= 2: # max is >= 2, not sym bool range
|
||||
lower = max(2, lower)
|
||||
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
|||
SymTypes,
|
||||
)
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
|
||||
from . import config, ir
|
||||
from .codegen.common import (
|
||||
|
|
@ -1427,18 +1428,21 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
vr = shape_env.var_to_range[i0]
|
||||
if not shape_env._default_unspecified_value_range().issubset(vr):
|
||||
|
||||
def convert(s):
|
||||
def is_convertible(s):
|
||||
if s in (int_oo, -int_oo):
|
||||
return False
|
||||
try:
|
||||
return int(s)
|
||||
int(s)
|
||||
return True
|
||||
except TypeError:
|
||||
return None
|
||||
return False
|
||||
|
||||
if (lower := convert(vr.lower)) is not None:
|
||||
if is_convertible(vr.lower):
|
||||
self.register_buffer(
|
||||
ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"),
|
||||
set_name=True,
|
||||
)
|
||||
if (upper := convert(vr.upper)) is not None:
|
||||
if is_convertible(vr.upper):
|
||||
self.register_buffer(
|
||||
ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"),
|
||||
set_name=True,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import builtins
|
||||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
|
|
@ -41,9 +40,11 @@ class _Dim(type):
|
|||
|
||||
@staticmethod
|
||||
def readable(name, min_, max_):
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
|
||||
if min_ == 2:
|
||||
min_ = None
|
||||
if max_ == sys.maxsize - 1:
|
||||
if max_ == int_oo:
|
||||
max_ = None
|
||||
if min_ is None and max_ is None:
|
||||
return f"Dim('{name}')"
|
||||
|
|
@ -140,6 +141,11 @@ class _DerivedDim(_Dim):
|
|||
# TODO(avik): use sympy value range analysis instead?
|
||||
from sympy import Integer
|
||||
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
|
||||
if self.root.min is -int_oo: # type: ignore[attr-defined]
|
||||
return -int_oo # fn not needed cuz increasing
|
||||
|
||||
_min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined]
|
||||
root = self.root # type: ignore[attr-defined]
|
||||
assert _min_symint >= 0, (
|
||||
|
|
@ -155,6 +161,11 @@ class _DerivedDim(_Dim):
|
|||
# TODO(avik): use sympy value range analysis instead?
|
||||
from sympy import Integer
|
||||
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
|
||||
if self.root.max is int_oo: # type: ignore[attr-defined]
|
||||
return int_oo # fn not needed cuz increasing
|
||||
|
||||
_max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined]
|
||||
root = self.root # type: ignore[attr-defined]
|
||||
assert _max_symint <= sys.maxsize - 1, (
|
||||
|
|
@ -190,8 +201,10 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
|
|||
Returns:
|
||||
A type that can be used in dynamic shape specifications for tensors.
|
||||
"""
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
|
||||
_min = 0 if min is None else min
|
||||
_max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1)
|
||||
_max = int_oo if max is None else max
|
||||
assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
|
||||
dim = _Dim(name, (int,), {"min": _min, "max": _max})
|
||||
dim.__module__ = getattr(
|
||||
|
|
@ -269,10 +282,11 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
|
|||
def _clone_with_range(self, lower=0, upper=None):
|
||||
# Import sympy locally
|
||||
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
if upper is None:
|
||||
upper = sys.maxsize - 1
|
||||
upper = int_oo
|
||||
|
||||
constraint_range = StrictMinMaxConstraint(
|
||||
vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
|
||||
|
|
@ -503,15 +517,14 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
|
|||
# Import sympy locally
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
return _create_constraint(
|
||||
weakref.ref(t),
|
||||
id(t),
|
||||
index,
|
||||
StrictMinMaxConstraint(
|
||||
vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False
|
||||
),
|
||||
StrictMinMaxConstraint(vr=ValueRanges(lower=0, upper=int_oo), warn_only=False),
|
||||
debug_name=debug_name,
|
||||
)
|
||||
|
||||
|
|
@ -725,6 +738,7 @@ def _process_dynamic_shapes(
|
|||
import sympy
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.solve import try_solve
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
|
|
@ -799,7 +813,7 @@ def _process_dynamic_shapes(
|
|||
constraint = dynamic_dim(tensor, i, debug_name=dim.__name__)
|
||||
if dim.min != 0:
|
||||
constraint = constraint >= dim.min
|
||||
if dim.max != sys.maxsize - 1:
|
||||
if dim.max != int_oo:
|
||||
constraint = constraint <= dim.max
|
||||
return constraint
|
||||
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ from torch.utils._sympy.functions import (
|
|||
FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt
|
||||
)
|
||||
from torch.utils._sympy.solve import try_solve
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
|
||||
from torch.utils._sympy.singleton_int import SingletonInt
|
||||
from torch.utils._traceback import format_frame, CapturedTraceback
|
||||
|
|
@ -871,9 +872,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
|
|||
for N=1.
|
||||
"""
|
||||
if min is None:
|
||||
min = -sys.maxsize - 1
|
||||
min = -int_oo
|
||||
if max is None:
|
||||
max = sys.maxsize - 1
|
||||
max = int_oo
|
||||
|
||||
if max < min:
|
||||
raise ValueError(
|
||||
|
|
@ -1382,6 +1383,7 @@ SYMPY_INTERP = {
|
|||
'PythonMod': operator.mod,
|
||||
'FloorDiv': operator.floordiv,
|
||||
'TrueDiv': operator.truediv,
|
||||
'PowByNatural': operator.pow,
|
||||
'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
|
||||
'floor': math.floor,
|
||||
'ceiling': math.ceil,
|
||||
|
|
@ -1994,7 +1996,7 @@ class DimConstraints:
|
|||
(dim.min < 2 and c.get("min", 2) == 2)
|
||||
or dim.min == c.get("min", 2)
|
||||
) # let pass if analysis min = 2 and specified min = 0/1
|
||||
and dim.max == c.get("max", sys.maxsize - 1)
|
||||
and dim.max == c.get("max", int_oo)
|
||||
)
|
||||
|
||||
# 1) newly introduced roots
|
||||
|
|
@ -2017,7 +2019,7 @@ class DimConstraints:
|
|||
modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
|
||||
c_min = c.get("min", 2)
|
||||
min_ = math.ceil((c_min - remainder) / modulus)
|
||||
c_max = c.get("max", sys.maxsize - 1)
|
||||
c_max = c.get("max", int_oo)
|
||||
max_ = math.floor((c_max - remainder) / modulus)
|
||||
# create result & dim
|
||||
results[str(root)] = {"min": min_, "max": max_}
|
||||
|
|
@ -2765,7 +2767,7 @@ class ShapeEnv:
|
|||
if min is None:
|
||||
min = 0
|
||||
if max is None:
|
||||
max = sys.maxsize - 1
|
||||
max = int_oo
|
||||
|
||||
if max < min:
|
||||
raise ValueError(
|
||||
|
|
@ -4094,7 +4096,7 @@ class ShapeEnv:
|
|||
|
||||
assert sources
|
||||
bounds = []
|
||||
if r.lower != -sympy.oo:
|
||||
if r.lower not in (-sympy.oo, -int_oo):
|
||||
if any(is_dim(source) for source in sources):
|
||||
self.dim_constraints.add(sympy.Ge(symbol, r.lower))
|
||||
# Only print lower bound in simplified mode if it is not the
|
||||
|
|
@ -4102,14 +4104,7 @@ class ShapeEnv:
|
|||
if not _simplified or r.lower != self._default_value_range().lower:
|
||||
bounds.append(str(r.lower))
|
||||
bounds.append(source_ref(sources[0]))
|
||||
# NB: This looks like an off-by-one error but it's not: the
|
||||
# upper bound may be sys.maxsize - 1 because we intentionally
|
||||
# exclude sys.maxsize from our bounds to deal with direct
|
||||
# == INT_MAX guards, but it's still dumb to actually test it.
|
||||
# Note that you can be off by a pretty large constant and it
|
||||
# won't matter because sizes in practice will be no where near
|
||||
# the 64-bit limit.
|
||||
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
|
||||
if r.upper not in (sympy.oo, int_oo):
|
||||
if any(is_dim(source) for source in sources):
|
||||
self.dim_constraints.add(sympy.Le(symbol, r.upper))
|
||||
# nontrivial upper bound is always interesting
|
||||
|
|
@ -4121,9 +4116,8 @@ class ShapeEnv:
|
|||
constraints = symbol_to_constraints[symbol]
|
||||
for c in constraints:
|
||||
if isinstance(c, StrictMinMaxConstraint):
|
||||
# NB: By default, we have a restrictive range
|
||||
# 2 <= s0 <= sys.maxsize - 1. But export users generally
|
||||
# expect to be able to specify nice ranges like [0, oo]
|
||||
# TODO: With int_oo, I think this condition is a noop
|
||||
# now
|
||||
if not (c.vr & self._default_value_range()).issubset(r):
|
||||
source = sources[0]
|
||||
|
||||
|
|
@ -4196,9 +4190,9 @@ class ShapeEnv:
|
|||
# Reason: '_maybe_evaluate_static' may eliminate guards based on the
|
||||
# refined value ranges.
|
||||
for sym, vr in self.var_to_range.items():
|
||||
if vr.lower != -sympy.oo:
|
||||
if vr.lower not in (-sympy.oo, -int_oo):
|
||||
self._add_target_expr(sympy.Le(vr.lower, sym))
|
||||
if vr.upper != sympy.oo:
|
||||
if vr.upper not in (sympy.oo, int_oo):
|
||||
self._add_target_expr(sympy.Le(sym, vr.upper))
|
||||
|
||||
# Before validating, populate the input of the validator with the
|
||||
|
|
@ -4330,9 +4324,14 @@ class ShapeEnv:
|
|||
var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols}
|
||||
if size_oblivious:
|
||||
# Clamp values of size-like variables
|
||||
# NB: discarding the old upper bound in intentional, per
|
||||
# https://github.com/pytorch/pytorch/pull/123675
|
||||
for x in self.size_like & var_to_range.keys():
|
||||
if var_to_range[x] is not None:
|
||||
var_to_range[x] = ValueRanges(2, sys.maxsize - 1)
|
||||
# NB: do NOT set upper to 2 ** 48, we're using this solely
|
||||
# to determine if we can do size-like replacement, the
|
||||
# upper bound is irrelevant here
|
||||
var_to_range[x] = ValueRanges(2, int_oo)
|
||||
assert var_to_range[x].is_int
|
||||
return bound_sympy(expr, var_to_range)
|
||||
|
||||
|
|
@ -4450,18 +4449,25 @@ class ShapeEnv:
|
|||
vr = self._default_unspecified_value_range()
|
||||
if size_oblivious and k in self.size_like:
|
||||
lower = max(2, vr.lower)
|
||||
# Clamping size-oblivious to some quantity below sys.maxsize
|
||||
# helps us determine that f(u0) != sys.maxsize, which is a
|
||||
# test that is looking for sys.maxsize as a sentinel, but you
|
||||
# don't really want to worry about it for unbacked SymInts.
|
||||
# This is similar to the flavor where size oblivious omits
|
||||
# 0/1, it changes semantics but in a benign way.
|
||||
upper = min(2 ** 48, vr.upper)
|
||||
# This is a bit dodgy: what this means is that there was a
|
||||
# size-like unbacked symbol whose upper bound < 2. This
|
||||
# causes... problems.
|
||||
if lower <= vr.upper:
|
||||
vr = ValueRanges(lower, vr.upper)
|
||||
if lower <= upper:
|
||||
vr = ValueRanges(lower, upper)
|
||||
else:
|
||||
lower = vr.lower
|
||||
# Don't do anything if we don't have a nontrivial lower bound
|
||||
# Also don't do anything if we asked only to simplify unbacked
|
||||
# SymInt
|
||||
if (
|
||||
lower < (-sys.maxsize - 1) // 2 or
|
||||
lower is -int_oo or
|
||||
(unbacked_only and k in self.var_to_val) or
|
||||
not vr.is_int
|
||||
):
|
||||
|
|
@ -4717,21 +4723,6 @@ class ShapeEnv:
|
|||
if a in self.var_to_range:
|
||||
src_bound = self.var_to_range[a]
|
||||
|
||||
# If you have x in [2, maxint], then 2*x in [4, 2*maxint].
|
||||
# But we don't really care that the max bound says we can
|
||||
# go beyond the maximum integer size, because we aren't
|
||||
# using bigints anyway. Arguably, ValueRanges should know
|
||||
# to do this truncation automaticaly (to avoid doing
|
||||
# bigint compute in range analysis), but right now it doesn't
|
||||
# so we need to get rid of some unnecessary precision.
|
||||
int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)
|
||||
|
||||
def issubset(x, y):
|
||||
if x.is_int and y.is_int:
|
||||
return (x & int_range).issubset(y & int_range)
|
||||
else:
|
||||
return x.issubset(y)
|
||||
|
||||
# First, refine the value range of a based on the computed value range
|
||||
# of tgt. This is always OK to do, even if we decide not to do the
|
||||
# substitution in the end. This might be a no-op, if a already has
|
||||
|
|
@ -4744,7 +4735,7 @@ class ShapeEnv:
|
|||
# - the source bound non-trivially improves over what we get out of
|
||||
# the existing bounds.
|
||||
# - the replacement is univariate and we can invert the tgt expression
|
||||
if not issubset(tgt_bound, src_bound) and len(tgt.free_symbols) == 1:
|
||||
if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1:
|
||||
b = next(iter(tgt.free_symbols))
|
||||
# Try to invert the equality
|
||||
r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
|
||||
|
|
@ -4759,7 +4750,7 @@ class ShapeEnv:
|
|||
b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper))
|
||||
self._update_var_to_range(b, b_bound)
|
||||
tgt_bound = self.bound_sympy(tgt)
|
||||
assert issubset(tgt_bound, src_bound)
|
||||
assert tgt_bound.issubset(src_bound)
|
||||
|
||||
# TODO: Should we propagate size-like-ness?
|
||||
#
|
||||
|
|
@ -4797,13 +4788,13 @@ class ShapeEnv:
|
|||
# - If the variable is unbacked, only substitute if the substitution
|
||||
# would preserve the bounds also under size-like-ness conditions.
|
||||
|
||||
if not issubset(tgt_bound, src_bound):
|
||||
if not tgt_bound.issubset(src_bound):
|
||||
self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound)
|
||||
return
|
||||
elif a in self.size_like:
|
||||
tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
|
||||
src_bound_so = self.bound_sympy(a, size_oblivious=True)
|
||||
if not issubset(tgt_bound_so, src_bound_so):
|
||||
if not tgt_bound_so.issubset(src_bound_so):
|
||||
self.log.debug("skipped set_replacement %s = %s (%s) "
|
||||
"[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so)
|
||||
return
|
||||
|
|
@ -4888,6 +4879,7 @@ class ShapeEnv:
|
|||
has_only_ephemeral_sources = (
|
||||
x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
|
||||
)
|
||||
# NB: size_hint is int, not sympy.Expr, do not use int_oo here
|
||||
size = self.size_hint(x, allow_none=True) or sys.maxsize
|
||||
name = x.name
|
||||
# 1 puts ephemeral sourced symbols first when sorting in reverse
|
||||
|
|
@ -4984,15 +4976,12 @@ class ShapeEnv:
|
|||
return
|
||||
|
||||
# See: Note - On 0/1 specialization
|
||||
# NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
|
||||
# as a sentinel sometimes. Your sizevar isn't going to be
|
||||
# anywhere near the max 64-bit integer anyway.
|
||||
def _default_value_range(self) -> ValueRanges:
|
||||
lower = 2 if self.specialize_zero_one else 0
|
||||
return ValueRanges(lower, sys.maxsize - 1)
|
||||
return ValueRanges(lower, int_oo)
|
||||
|
||||
def _default_unspecified_value_range(self) -> ValueRanges:
|
||||
return ValueRanges(-sys.maxsize - 1, sys.maxsize)
|
||||
return ValueRanges(-int_oo, int_oo)
|
||||
|
||||
@_lru_cache
|
||||
def _simplify_floor_div(self, expr):
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ def insert_deferred_runtime_asserts(
|
|||
):
|
||||
assert len(node.args) == 1
|
||||
nodes_that_already_have_sym_constraint_range.add(
|
||||
(node.args[0], node.kwargs["min"], node.kwargs["max"])
|
||||
(node.args[0], node.kwargs.get("min"), node.kwargs.get("max"))
|
||||
)
|
||||
if (
|
||||
node.op == "call_function"
|
||||
|
|
@ -86,6 +86,7 @@ def insert_deferred_runtime_asserts(
|
|||
InnerTensorKey,
|
||||
)
|
||||
from torch.utils._sympy.interp import sympy_interp
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.reference import PythonReferenceAnalysis
|
||||
|
||||
# TODO: Request simplification on runtime asserts before emitting them
|
||||
|
|
@ -367,6 +368,8 @@ def insert_deferred_runtime_asserts(
|
|||
# (refinement should not be necessary once runtime
|
||||
# asserts cause refinement, but that's NYI)
|
||||
def convert(s):
|
||||
if s in (int_oo, -int_oo):
|
||||
return None
|
||||
try:
|
||||
return int(s)
|
||||
except TypeError:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import sys
|
|||
import sympy
|
||||
from sympy import S
|
||||
|
||||
from .numbers import int_oo
|
||||
|
||||
__all__ = [
|
||||
"FloorDiv",
|
||||
"ModularIndexing",
|
||||
|
|
@ -101,6 +103,15 @@ class FloorDiv(sympy.Function):
|
|||
# makes it difficult to check the types.
|
||||
if divisor.is_zero:
|
||||
raise ZeroDivisionError("division by zero")
|
||||
if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in (
|
||||
int_oo,
|
||||
-int_oo,
|
||||
sympy.oo,
|
||||
-sympy.oo,
|
||||
):
|
||||
return sympy.nan
|
||||
if base is sympy.nan or divisor is sympy.nan:
|
||||
return sympy.nan
|
||||
|
||||
if base.is_zero:
|
||||
return sympy.S.Zero
|
||||
|
|
@ -108,6 +119,23 @@ class FloorDiv(sympy.Function):
|
|||
return base
|
||||
if base.is_integer and divisor == -1:
|
||||
return sympy.Mul(base, -1)
|
||||
if (
|
||||
isinstance(base, sympy.Number)
|
||||
and isinstance(divisor, sympy.Number)
|
||||
and (
|
||||
base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
|
||||
or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
|
||||
)
|
||||
):
|
||||
r = float(base) / float(divisor)
|
||||
if r == math.inf:
|
||||
return int_oo
|
||||
elif r == -math.inf:
|
||||
return -int_oo
|
||||
elif math.isnan(r):
|
||||
return sympy.nan
|
||||
else:
|
||||
return sympy.Integer(math.floor(r))
|
||||
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
|
||||
return sympy.Integer(int(base) // int(divisor))
|
||||
if isinstance(base, FloorDiv):
|
||||
|
|
@ -353,10 +381,10 @@ class CeilToInt(sympy.Function):
|
|||
@classmethod
|
||||
def eval(cls, number):
|
||||
# assert number.is_integer is not True, number
|
||||
if number == sympy.oo:
|
||||
return sympy.Integer(sys.maxsize - 1)
|
||||
if number == -sympy.oo:
|
||||
return sympy.Integer(-sys.maxsize - 1)
|
||||
if number in (sympy.oo, int_oo):
|
||||
return int_oo
|
||||
if number in (-sympy.oo, -int_oo):
|
||||
return -int_oo
|
||||
if isinstance(number, sympy.Number):
|
||||
return sympy.Integer(math.ceil(float(number)))
|
||||
|
||||
|
|
@ -367,10 +395,10 @@ class FloorToInt(sympy.Function):
|
|||
@classmethod
|
||||
def eval(cls, number):
|
||||
# assert number.is_integer is not True, number
|
||||
if number == sympy.oo:
|
||||
return sympy.Integer(sys.maxsize - 1)
|
||||
if number == -sympy.oo:
|
||||
return sympy.Integer(-sys.maxsize - 1)
|
||||
if number in (sympy.oo, int_oo):
|
||||
return int_oo
|
||||
if number in (-sympy.oo, int_oo):
|
||||
return -int_oo
|
||||
if isinstance(number, sympy.Number):
|
||||
return sympy.Integer(math.floor(float(number)))
|
||||
|
||||
|
|
@ -419,6 +447,7 @@ def safe_pow(base, exp):
|
|||
return sign * _safe_pow(base, exp)
|
||||
|
||||
|
||||
# Prevent people from overflowing pow
|
||||
def _safe_pow(base, exponent):
|
||||
if exponent < 0:
|
||||
raise ValueError("Exponent must be non-negative.")
|
||||
|
|
@ -427,17 +456,20 @@ def _safe_pow(base, exponent):
|
|||
return 1
|
||||
|
||||
half_exp = safe_pow(base, exponent // 2)
|
||||
if half_exp > sys.maxsize - 1:
|
||||
return sys.maxsize - 1
|
||||
if half_exp is int_oo:
|
||||
return int_oo
|
||||
|
||||
# TODO: microoptimization is to avoid overflowing into arbitrary precision
|
||||
# and detect overflow prior to doing operations
|
||||
|
||||
result = half_exp * half_exp
|
||||
if result > sys.maxsize - 1:
|
||||
return sys.maxsize - 1
|
||||
if result > sys.maxsize:
|
||||
return int_oo
|
||||
|
||||
if exponent % 2 == 1:
|
||||
result *= base
|
||||
if result > sys.maxsize - 1:
|
||||
return sys.maxsize - 1
|
||||
if result > sys.maxsize:
|
||||
return int_oo
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -447,14 +479,20 @@ class PowByNatural(sympy.Function):
|
|||
|
||||
@classmethod
|
||||
def eval(cls, base, exp):
|
||||
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
|
||||
return sympy.Integer(safe_pow(base, exp))
|
||||
if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
|
||||
r = safe_pow(base, exp)
|
||||
if r in (-int_oo, int_oo):
|
||||
return r
|
||||
return sympy.Integer(r)
|
||||
if isinstance(exp, sympy.Integer):
|
||||
# Translate power into iterated multiplication
|
||||
r = sympy.Integer(1)
|
||||
for _ in range(int(exp)):
|
||||
r *= base
|
||||
return r
|
||||
# Rely on regular sympy Pow for this (note that iterated
|
||||
# multiplication turns into a Pow anyway, you can't escape!!)
|
||||
return sympy.Pow(base, exp)
|
||||
if exp in (int_oo, sympy.oo):
|
||||
if base.is_nonnegative:
|
||||
return int_oo
|
||||
elif base.is_negative:
|
||||
return sympy.zoo # this is apparently what (-2)**sympy.oo does
|
||||
# NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
|
||||
# is a natural number if we do
|
||||
|
||||
|
|
@ -467,6 +505,11 @@ class FloatPow(sympy.Function):
|
|||
|
||||
@classmethod
|
||||
def eval(cls, base, exp):
|
||||
# NB: These test sympy.Number, not sympy.Float, because:
|
||||
# - Sometimes we may have sympy.oo or int_oo, and that's not a Float
|
||||
# (but coerces to math.Inf)
|
||||
# - Sometimes Float(0.0) will unpredictably decay to Integer(0),
|
||||
# but we should still accept it in floatey contexts
|
||||
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
|
||||
return sympy.Float(float(base) ** float(exp))
|
||||
# NB: do not do any nontrivial reasoning
|
||||
|
|
@ -510,7 +553,18 @@ class IntTrueDiv(sympy.Function):
|
|||
if divisor.is_zero:
|
||||
raise ZeroDivisionError("division by zero")
|
||||
|
||||
if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
|
||||
if (
|
||||
isinstance(base, sympy.Number)
|
||||
and isinstance(divisor, sympy.Number)
|
||||
and (
|
||||
base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
|
||||
or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
|
||||
)
|
||||
):
|
||||
# Don't have to worry about precision here, you're getting zero or
|
||||
# inf from the division
|
||||
return sympy.Float(float(base) / float(divisor))
|
||||
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
|
||||
return sympy.Float(int(base) / int(divisor))
|
||||
|
||||
|
||||
|
|
@ -567,10 +621,10 @@ class TruncToInt(sympy.Function):
|
|||
@classmethod
|
||||
def eval(cls, number):
|
||||
# assert number.is_integer is not True, number
|
||||
if number == sympy.oo:
|
||||
return sympy.Integer(sys.maxsize - 1)
|
||||
if number == -sympy.oo:
|
||||
return sympy.Integer(-sys.maxsize - 1)
|
||||
if number in (sympy.oo, int_oo):
|
||||
return int_oo
|
||||
if number in (-sympy.oo, -int_oo):
|
||||
return -int_oo
|
||||
if isinstance(number, sympy.Number):
|
||||
return sympy.Integer(math.trunc(float(number)))
|
||||
|
||||
|
|
@ -583,7 +637,11 @@ class RoundToInt(sympy.Function):
|
|||
def eval(cls, number):
|
||||
# assert number.is_integer is not True, number
|
||||
|
||||
if isinstance(number, sympy.Float):
|
||||
if number is sympy.oo:
|
||||
return int_oo
|
||||
if number is -sympy.oo:
|
||||
return -int_oo
|
||||
if isinstance(number, sympy.Number):
|
||||
return sympy.Integer(round(float(number), 0))
|
||||
|
||||
|
||||
|
|
@ -610,7 +668,7 @@ class RoundDecimal(sympy.Function):
|
|||
def eval(cls, number, ndigits):
|
||||
# assert number.is_integer is not True, number
|
||||
|
||||
if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer):
|
||||
if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
|
||||
return sympy.Float(round(float(number), int(ndigits)))
|
||||
|
||||
|
||||
|
|
@ -625,6 +683,10 @@ class ToFloat(sympy.Function):
|
|||
|
||||
if isinstance(number, sympy.Integer):
|
||||
return sympy.Float(int(number))
|
||||
if number is int_oo:
|
||||
return sympy.oo
|
||||
if number is -int_oo:
|
||||
return -sympy.oo
|
||||
|
||||
|
||||
def make_opaque_unary_fn(name):
|
||||
|
|
@ -655,7 +717,11 @@ def make_opaque_unary_fn(name):
|
|||
# weird objects but ask silly questions, get silly answers
|
||||
except OverflowError:
|
||||
return getattr(sympy, name)(a)
|
||||
elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo]:
|
||||
elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]:
|
||||
if a is int_oo:
|
||||
a = sympy.oo
|
||||
if a is -int_oo:
|
||||
a = -sympy.oo
|
||||
return getattr(sympy, name)(a)
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
|
|||
"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import sympy
|
||||
|
|
@ -37,6 +38,9 @@ from .functions import (
|
|||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: Dedupe this with SYMPY_INTERP
|
||||
|
||||
|
||||
|
|
@ -157,11 +161,18 @@ def sympy_interp(
|
|||
else:
|
||||
handler_name = handlers()[expr.func]
|
||||
handler = getattr(analysis, handler_name)
|
||||
if handler_name in ASSOCIATIVE_OPS:
|
||||
assert len(args) > 1
|
||||
acc = handler(args[0], args[1])
|
||||
for i in range(2, len(args)):
|
||||
acc = handler(acc, args[i])
|
||||
return acc
|
||||
else:
|
||||
return handler(*args)
|
||||
try:
|
||||
if handler_name in ASSOCIATIVE_OPS:
|
||||
assert len(args) > 1
|
||||
acc = handler(args[0], args[1])
|
||||
for i in range(2, len(args)):
|
||||
acc = handler(acc, args[i])
|
||||
log.debug("%s(%s) -> %s", handler_name, args, acc)
|
||||
return acc
|
||||
else:
|
||||
r = handler(*args)
|
||||
log.debug("%s(%s) -> %s", handler_name, args, r)
|
||||
return r
|
||||
except Exception:
|
||||
log.warning("failed while executing %s(%s)", handler_name, args)
|
||||
raise
|
||||
|
|
|
|||
394
torch/utils/_sympy/numbers.py
Normal file
394
torch/utils/_sympy/numbers.py
Normal file
|
|
@ -0,0 +1,394 @@
|
|||
import mpmath.libmp as mlib # type: ignore[import-untyped]
|
||||
import sympy
|
||||
from sympy import Expr
|
||||
from sympy.core.decorators import _sympifyit
|
||||
from sympy.core.expr import AtomicExpr
|
||||
from sympy.core.numbers import Number
|
||||
from sympy.core.parameters import global_parameters
|
||||
from sympy.core.singleton import S, Singleton
|
||||
|
||||
|
||||
class IntInfinity(Number, metaclass=Singleton):
|
||||
r"""Positive integer infinite quantity.
|
||||
|
||||
Integer infinity is a value in an extended integers which
|
||||
is greater than all other integers. We distinguish it from
|
||||
sympy's existing notion of infinity in that it reports that
|
||||
it is_integer.
|
||||
|
||||
Infinity is a singleton, and can be accessed by ``S.IntInfinity``,
|
||||
or can be imported as ``int_oo``.
|
||||
"""
|
||||
|
||||
# NB: We can't actually mark this as infinite, as integer and infinite are
|
||||
# inconsistent assumptions in sympy. We also report that we are complex,
|
||||
# different from sympy.oo
|
||||
|
||||
is_integer = True
|
||||
is_commutative = True
|
||||
is_number = True
|
||||
is_extended_real = True
|
||||
is_comparable = True
|
||||
is_extended_positive = True
|
||||
is_prime = False
|
||||
|
||||
# Ensure we get dispatched to before plain numbers
|
||||
_op_priority = 100.0
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls):
|
||||
return AtomicExpr.__new__(cls)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return "int_oo"
|
||||
|
||||
def _eval_subs(self, old, new):
|
||||
if self == old:
|
||||
return new
|
||||
|
||||
# We could do these, not sure about it
|
||||
"""
|
||||
def _eval_evalf(self, prec=None):
|
||||
return Float('inf')
|
||||
|
||||
def evalf(self, prec=None, **options):
|
||||
return self._eval_evalf(prec)
|
||||
"""
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __add__(self, other):
|
||||
if isinstance(other, Number) and global_parameters.evaluate:
|
||||
if other is S.NegativeInfinity:
|
||||
return S.NegativeInfinity
|
||||
if other in (S.NegativeIntInfinity, S.NaN):
|
||||
return S.NaN
|
||||
return self
|
||||
return Number.__add__(self, other)
|
||||
|
||||
__radd__ = __add__
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __sub__(self, other):
|
||||
if isinstance(other, Number) and global_parameters.evaluate:
|
||||
if other is S.Infinity:
|
||||
return S.NegativeInfinity
|
||||
if other in (S.IntInfinity, S.NaN):
|
||||
return S.NaN
|
||||
return self
|
||||
return Number.__sub__(self, other)
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __rsub__(self, other):
|
||||
return (-self).__add__(other)
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, Number) and global_parameters.evaluate:
|
||||
if other.is_zero or other is S.NaN:
|
||||
return S.NaN
|
||||
if other.is_extended_positive:
|
||||
return self
|
||||
return S.NegativeIntInfinity
|
||||
return Number.__mul__(self, other)
|
||||
|
||||
__rmul__ = __mul__
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __truediv__(self, other):
|
||||
if isinstance(other, Number) and global_parameters.evaluate:
|
||||
if other in (
|
||||
S.Infinity,
|
||||
S.IntInfinity,
|
||||
S.NegativeInfinity,
|
||||
S.NegativeIntInfinity,
|
||||
S.NaN,
|
||||
):
|
||||
return S.NaN
|
||||
if other.is_extended_nonnegative:
|
||||
return S.Infinity # truediv produces float
|
||||
return S.NegativeInfinity # truediv produces float
|
||||
return Number.__truediv__(self, other)
|
||||
|
||||
def __abs__(self):
|
||||
return S.IntInfinity
|
||||
|
||||
def __neg__(self):
|
||||
return S.NegativeIntInfinity
|
||||
|
||||
def _eval_power(self, expt):
|
||||
if expt.is_extended_positive:
|
||||
return S.IntInfinity
|
||||
if expt.is_extended_negative:
|
||||
return S.Zero
|
||||
if expt is S.NaN:
|
||||
return S.NaN
|
||||
if expt is S.ComplexInfinity:
|
||||
return S.NaN
|
||||
if expt.is_extended_real is False and expt.is_number:
|
||||
from sympy.functions.elementary.complexes import re
|
||||
|
||||
expt_real = re(expt)
|
||||
if expt_real.is_positive:
|
||||
return S.ComplexInfinity
|
||||
if expt_real.is_negative:
|
||||
return S.Zero
|
||||
if expt_real.is_zero:
|
||||
return S.NaN
|
||||
|
||||
return self ** expt.evalf()
|
||||
|
||||
def _as_mpf_val(self, prec):
|
||||
return mlib.finf
|
||||
|
||||
def __hash__(self):
|
||||
return super().__hash__()
|
||||
|
||||
def __eq__(self, other):
|
||||
return other is S.IntInfinity
|
||||
|
||||
def __ne__(self, other):
|
||||
return other is not S.IntInfinity
|
||||
|
||||
def __gt__(self, other):
|
||||
if other is S.Infinity:
|
||||
return sympy.false # sympy.oo > int_oo
|
||||
elif other is S.IntInfinity:
|
||||
return sympy.false # consistency with sympy.oo
|
||||
else:
|
||||
return sympy.true
|
||||
|
||||
def __ge__(self, other):
|
||||
if other is S.Infinity:
|
||||
return sympy.false # sympy.oo > int_oo
|
||||
elif other is S.IntInfinity:
|
||||
return sympy.true # consistency with sympy.oo
|
||||
else:
|
||||
return sympy.true
|
||||
|
||||
def __lt__(self, other):
|
||||
if other is S.Infinity:
|
||||
return sympy.true # sympy.oo > int_oo
|
||||
elif other is S.IntInfinity:
|
||||
return sympy.false # consistency with sympy.oo
|
||||
else:
|
||||
return sympy.false
|
||||
|
||||
def __le__(self, other):
|
||||
if other is S.Infinity:
|
||||
return sympy.true # sympy.oo > int_oo
|
||||
elif other is S.IntInfinity:
|
||||
return sympy.true # consistency with sympy.oo
|
||||
else:
|
||||
return sympy.false
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __mod__(self, other):
|
||||
if not isinstance(other, Expr):
|
||||
return NotImplemented
|
||||
return S.NaN
|
||||
|
||||
__rmod__ = __mod__
|
||||
|
||||
def floor(self):
|
||||
return self
|
||||
|
||||
def ceiling(self):
|
||||
return self
|
||||
|
||||
|
||||
int_oo = S.IntInfinity
|
||||
|
||||
|
||||
class NegativeIntInfinity(Number, metaclass=Singleton):
|
||||
"""Negative integer infinite quantity.
|
||||
|
||||
NegativeInfinity is a singleton, and can be accessed
|
||||
by ``S.NegativeInfinity``.
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
IntInfinity
|
||||
"""
|
||||
|
||||
# Ensure we get dispatched to before plain numbers
|
||||
_op_priority = 100.0
|
||||
|
||||
is_integer = True
|
||||
is_extended_real = True
|
||||
is_commutative = True
|
||||
is_comparable = True
|
||||
is_extended_negative = True
|
||||
is_number = True
|
||||
is_prime = False
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls):
|
||||
return AtomicExpr.__new__(cls)
|
||||
|
||||
def _eval_subs(self, old, new):
|
||||
if self == old:
|
||||
return new
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return "-int_oo"
|
||||
|
||||
"""
|
||||
def _eval_evalf(self, prec=None):
|
||||
return Float('-inf')
|
||||
|
||||
def evalf(self, prec=None, **options):
|
||||
return self._eval_evalf(prec)
|
||||
"""
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __add__(self, other):
|
||||
if isinstance(other, Number) and global_parameters.evaluate:
|
||||
if other is S.Infinity:
|
||||
return S.Infinity
|
||||
if other in (S.IntInfinity, S.NaN):
|
||||
return S.NaN
|
||||
return self
|
||||
return Number.__add__(self, other)
|
||||
|
||||
__radd__ = __add__
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __sub__(self, other):
|
||||
if isinstance(other, Number) and global_parameters.evaluate:
|
||||
if other is S.NegativeInfinity:
|
||||
return S.Infinity
|
||||
if other in (S.NegativeIntInfinity, S.NaN):
|
||||
return S.NaN
|
||||
return self
|
||||
return Number.__sub__(self, other)
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __rsub__(self, other):
|
||||
return (-self).__add__(other)
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, Number) and global_parameters.evaluate:
|
||||
if other.is_zero or other is S.NaN:
|
||||
return S.NaN
|
||||
if other.is_extended_positive:
|
||||
return self
|
||||
return S.IntInfinity
|
||||
return Number.__mul__(self, other)
|
||||
|
||||
__rmul__ = __mul__
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __truediv__(self, other):
|
||||
if isinstance(other, Number) and global_parameters.evaluate:
|
||||
if other in (
|
||||
S.Infinity,
|
||||
S.IntInfinity,
|
||||
S.NegativeInfinity,
|
||||
S.NegativeIntInfinity,
|
||||
S.NaN,
|
||||
):
|
||||
return S.NaN
|
||||
if other.is_extended_nonnegative:
|
||||
return self
|
||||
return S.Infinity # truediv returns float
|
||||
return Number.__truediv__(self, other)
|
||||
|
||||
def __abs__(self):
|
||||
return S.IntInfinity
|
||||
|
||||
def __neg__(self):
|
||||
return S.IntInfinity
|
||||
|
||||
def _eval_power(self, expt):
|
||||
if expt.is_number:
|
||||
if expt in (
|
||||
S.NaN,
|
||||
S.Infinity,
|
||||
S.NegativeInfinity,
|
||||
S.IntInfinity,
|
||||
S.NegativeIntInfinity,
|
||||
):
|
||||
return S.NaN
|
||||
|
||||
if isinstance(expt, sympy.Integer) and expt.is_extended_positive:
|
||||
if expt.is_odd:
|
||||
return S.NegativeIntInfinity
|
||||
else:
|
||||
return S.IntInfinity
|
||||
|
||||
inf_part = S.IntInfinity**expt
|
||||
s_part = S.NegativeOne**expt
|
||||
if inf_part == 0 and s_part.is_finite:
|
||||
return inf_part
|
||||
if (
|
||||
inf_part is S.ComplexInfinity
|
||||
and s_part.is_finite
|
||||
and not s_part.is_zero
|
||||
):
|
||||
return S.ComplexInfinity
|
||||
return s_part * inf_part
|
||||
|
||||
def _as_mpf_val(self, prec):
|
||||
return mlib.fninf
|
||||
|
||||
def __hash__(self):
|
||||
return super().__hash__()
|
||||
|
||||
def __eq__(self, other):
|
||||
return other is S.NegativeIntInfinity
|
||||
|
||||
def __ne__(self, other):
|
||||
return other is not S.NegativeIntInfinity
|
||||
|
||||
def __gt__(self, other):
|
||||
if other is S.NegativeInfinity:
|
||||
return sympy.true # -sympy.oo < -int_oo
|
||||
elif other is S.NegativeIntInfinity:
|
||||
return sympy.false # consistency with sympy.oo
|
||||
else:
|
||||
return sympy.false
|
||||
|
||||
def __ge__(self, other):
|
||||
if other is S.NegativeInfinity:
|
||||
return sympy.true # -sympy.oo < -int_oo
|
||||
elif other is S.NegativeIntInfinity:
|
||||
return sympy.true # consistency with sympy.oo
|
||||
else:
|
||||
return sympy.false
|
||||
|
||||
def __lt__(self, other):
|
||||
if other is S.NegativeInfinity:
|
||||
return sympy.false # -sympy.oo < -int_oo
|
||||
elif other is S.NegativeIntInfinity:
|
||||
return sympy.false # consistency with sympy.oo
|
||||
else:
|
||||
return sympy.true
|
||||
|
||||
def __le__(self, other):
|
||||
if other is S.NegativeInfinity:
|
||||
return sympy.false # -sympy.oo < -int_oo
|
||||
elif other is S.NegativeIntInfinity:
|
||||
return sympy.true # consistency with sympy.oo
|
||||
else:
|
||||
return sympy.true
|
||||
|
||||
@_sympifyit("other", NotImplemented)
|
||||
def __mod__(self, other):
|
||||
if not isinstance(other, Expr):
|
||||
return NotImplemented
|
||||
return S.NaN
|
||||
|
||||
__rmod__ = __mod__
|
||||
|
||||
def floor(self):
|
||||
return self
|
||||
|
||||
def ceiling(self):
|
||||
return self
|
||||
|
||||
def as_powers_dict(self):
|
||||
return {S.NegativeOne: 1, S.IntInfinity: 1}
|
||||
|
|
@ -6,7 +6,6 @@ import itertools
|
|||
import logging
|
||||
import math
|
||||
import operator
|
||||
import sys
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
|
|
@ -24,6 +23,7 @@ import sympy
|
|||
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
|
||||
|
||||
import torch
|
||||
from torch._logging import LazyString
|
||||
|
||||
from torch._prims_common import dtype_to_type
|
||||
from .functions import (
|
||||
|
|
@ -43,6 +43,7 @@ from .functions import (
|
|||
TruncToInt,
|
||||
)
|
||||
from .interp import sympy_interp
|
||||
from .numbers import int_oo, IntInfinity, NegativeIntInfinity
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -168,7 +169,10 @@ class ValueRanges(Generic[_T]):
|
|||
self,
|
||||
"is_int",
|
||||
not self.is_bool
|
||||
and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)),
|
||||
and (
|
||||
isinstance(lower, (sympy.Integer, NegativeIntInfinity))
|
||||
or isinstance(upper, (sympy.Integer, IntInfinity))
|
||||
),
|
||||
)
|
||||
"""
|
||||
# This assert is just impossible right now, too many sympy bugs
|
||||
|
|
@ -265,11 +269,14 @@ class ValueRanges(Generic[_T]):
|
|||
def is_singleton(self) -> bool:
|
||||
return self.lower == self.upper
|
||||
|
||||
# TODO: this doesn't work with bools but arguably it should
|
||||
@staticmethod
|
||||
def unknown() -> ValueRanges[sympy.Expr]:
|
||||
return ValueRanges(-sympy.oo, sympy.oo)
|
||||
|
||||
@staticmethod
|
||||
def unknown_int() -> ValueRanges[sympy.Expr]:
|
||||
return ValueRanges(-int_oo, int_oo)
|
||||
|
||||
@staticmethod
|
||||
def unknown_bool() -> ValueRanges[SympyBoolean]:
|
||||
return ValueRanges(sympy.false, sympy.true)
|
||||
|
|
@ -401,7 +408,7 @@ class SymPyValueRangeAnalysis:
|
|||
elif dtype.is_floating_point:
|
||||
return ValueRanges.unknown()
|
||||
else:
|
||||
return ValueRanges(-sys.maxsize - 1, sys.maxsize)
|
||||
return ValueRanges(-int_oo, int_oo)
|
||||
|
||||
if is_python:
|
||||
type_ = dtype_to_type(dtype)
|
||||
|
|
@ -424,6 +431,10 @@ class SymPyValueRangeAnalysis:
|
|||
def to_dtype(a, dtype, src_dtype=None):
|
||||
if dtype == torch.float64:
|
||||
return ValueRanges.increasing_map(a, ToFloat)
|
||||
elif dtype == torch.bool:
|
||||
return ValueRanges.unknown_bool()
|
||||
elif not dtype.is_floating_point:
|
||||
return ValueRanges.unknown_int()
|
||||
return ValueRanges.unknown()
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -515,9 +526,7 @@ class SymPyValueRangeAnalysis:
|
|||
def int_truediv(a, b):
|
||||
a = ValueRanges.wrap(a)
|
||||
b = ValueRanges.wrap(b)
|
||||
if 0 in b or (
|
||||
(-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
|
||||
):
|
||||
if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)):
|
||||
return ValueRanges.unknown()
|
||||
else:
|
||||
return ValueRanges.coordinatewise_monotone_map(
|
||||
|
|
@ -541,14 +550,17 @@ class SymPyValueRangeAnalysis:
|
|||
def floordiv(a, b):
|
||||
a = ValueRanges.wrap(a)
|
||||
b = ValueRanges.wrap(b)
|
||||
if 0 in b or (
|
||||
# TODO: make this more precise
|
||||
(-sympy.oo in a or sympy.oo in a)
|
||||
or (-sympy.oo in b or sympy.oo in b)
|
||||
):
|
||||
if 0 in b:
|
||||
return ValueRanges.unknown()
|
||||
else:
|
||||
return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv)
|
||||
products = []
|
||||
for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]):
|
||||
r = FloorDiv(x, y)
|
||||
if r is sympy.nan:
|
||||
products.append((sympy.sign(x) * sympy.sign(y)) * int_oo)
|
||||
else:
|
||||
products.append(r)
|
||||
|
||||
return ValueRanges(min(products), max(products))
|
||||
|
||||
@classmethod
|
||||
def mod(cls, x, y):
|
||||
|
|
@ -564,10 +576,10 @@ class SymPyValueRangeAnalysis:
|
|||
|
||||
def c_div(a, b):
|
||||
x = a / b
|
||||
return sympy.Integer(x) if x.is_finite else x
|
||||
return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x
|
||||
|
||||
if 0 in y:
|
||||
return ValueRanges.unknown()
|
||||
return ValueRanges.unknown_int()
|
||||
elif y.is_singleton():
|
||||
y_val = abs(y.lower)
|
||||
# If it wraps, we need to take the whole interval
|
||||
|
|
@ -597,7 +609,7 @@ class SymPyValueRangeAnalysis:
|
|||
|
||||
@classmethod
|
||||
def is_non_overlapping_and_dense_indicator(cls, *args):
|
||||
return ValueRanges.unknown() # TODO: type here is wrong
|
||||
return ValueRanges.unknown_int()
|
||||
|
||||
@classmethod
|
||||
def pow_by_natural(cls, a, b):
|
||||
|
|
@ -611,7 +623,7 @@ class SymPyValueRangeAnalysis:
|
|||
# to replacements, so don't assert it, but DO clamp it to prevent
|
||||
# degenerate problems
|
||||
return ValueRanges.coordinatewise_increasing_map(
|
||||
a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural
|
||||
a, b & ValueRanges(0, int_oo), PowByNatural
|
||||
)
|
||||
elif b.is_singleton():
|
||||
if b.lower % 2 == 0:
|
||||
|
|
@ -939,6 +951,8 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
|
|||
if dtype.is_floating_point:
|
||||
return sympy.Float(x)
|
||||
else:
|
||||
if x in (int_oo, -int_oo):
|
||||
return x
|
||||
try:
|
||||
return sympy.Integer(x)
|
||||
except TypeError:
|
||||
|
|
@ -986,7 +1000,18 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
|
|||
def bound_sympy(
|
||||
expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None
|
||||
) -> ValueRanges:
|
||||
log.debug("bound_sympy(%s, %s)", expr, ranges)
|
||||
log.debug(
|
||||
"bound_sympy(%s)%s",
|
||||
expr,
|
||||
LazyString(
|
||||
lambda: "\n"
|
||||
+ "\n".join(
|
||||
f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
|
||||
)
|
||||
if ranges
|
||||
else ""
|
||||
),
|
||||
)
|
||||
if isinstance(expr, sympy.Number):
|
||||
return ValueRanges.wrap(expr)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user