Introduce int_oo (#127693)

In a previous life, we used sympy.oo to represent the lower/upper bounds of integer ranges. Later, we changed this to be sys.maxsize - 1 for a few reasons: (1) sometimes we do tests on a value being exactly sys.maxsize, and we wanted to avoid a data dependent guard in this case, (2) sympy.oo corresponds to floating point infinity, so you get incorrect types for value ranges with oo, and (3) you can do slightly better reasoning if you assume that input sizes fall within representable 64-bit integer range.

After working in the sys.maxsize regime for a bit, I've concluded that this was actually a bad idea. Specifically, the problem is that you end up with sys.maxsize in your upper bound, and then whenever you do any sort of size-increasing computation like size * 2, you end up with 2 * sys.maxsize, and you end up doing a ton of arbitrary precision int computation that is totally unnecessary. A symbolic bound is better.

But especially after #126905, we can't go back to using sympy.oo, because that advertises that it's not an integer, and now your ValueRanges is typed incorrectly. So what do we do? We define a new numeric constant `int_oo`, which is like `sympy.oo` but it advertises `is_integer`. **test/test_sympy_utils.py** describes some basic properties of the number, and **torch/utils/_sympy/numbers.py** has the actual implementation.

The rest of the changes of the PR are working out the implications of this change. I'll give more commentary as inline comments.

Fixes https://github.com/pytorch/pytorch/issues/127396

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127693
Approved by: https://github.com/lezcano
ghstack dependencies: #126905
This commit is contained in:
Edward Z. Yang 2024-06-09 09:48:47 -07:00 committed by PyTorch MergeBot
parent db2fa7b827
commit 9cab5987bd
19 changed files with 746 additions and 145 deletions

View File

@ -253,7 +253,6 @@ Target Expressions:
==> (>= 0 s1)
==> (>= 0 s2)
==> (>= 0 s3)
==> (>= 9223372036854775806 s0)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
@ -287,14 +286,14 @@ Failure occurred while running node:
Model:
==> L['shape'][0]: 1
==> L['shape'][1]: 1
==> L['shape'][2]: 2
==> L['shape'][2]: 0
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 1
==> s2: 1
==> s3: 2
==> s3: 0
Assertions:
==> (== 0 L['x'].storage_offset())
@ -318,10 +317,6 @@ Target Expressions:
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
==> (>= 9223372036854775806 s0)
==> (>= 9223372036854775807 s1)
==> (>= 9223372036854775807 s2)
==> (>= 9223372036854775807 s3)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",

View File

@ -3473,7 +3473,6 @@ class GraphModule(torch.nn.Module):
]
false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
]
test_symbool_guards(
f,

View File

@ -9309,7 +9309,7 @@ ShapeEnv not equal: field values don't match:
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
> Right: {0: 0, 1: 1}
==> var_to_range: values don't match.
> Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
> Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Right: {}
==> var_to_sources: values don't match.
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
@ -9343,7 +9343,7 @@ ShapeEnv not equal: field values don't match:
> Left: 2
> Right: 0
==> var_to_range: values don't match.
> Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]}
> Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]}
> Right: {}
""",
)
@ -9420,8 +9420,8 @@ ShapeEnv not equal: field values don't match:
> Left: {s0: 3}
> Right: {}
==> var_to_range: values don't match.
> Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]}
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
> Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
""",
)
self._replay_and_check(main)
@ -9458,8 +9458,8 @@ ShapeEnv not equal: field values don't match:
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match.
> Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]}
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
> Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
""",
)
self._replay_and_check(main)

View File

@ -201,6 +201,19 @@ class TestDynamismExpression(TestCase):
dynamic_shapes={"x": {0: dim_x}},
)
def test_export_slice_maxsize(self):
class Slice(torch.nn.Module):
def forward(self, *args):
return torch.ops.aten.slice.Tensor(*args)
inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
torch.export.export(
Slice(),
inp,
dynamic_shapes=dynamic_shapes,
)
def test_export_constraints_error(self):
class ConflictingConstraints(torch.nn.Module):
def forward(self, x):
@ -5183,7 +5196,7 @@ def forward(self, x, y):
}
export(f, (inputs,), dynamic_shapes=dynamic_shapes)
def test_disable_forced_specializations(self):
def test_disable_forced_specializations_ok(self):
# check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags
# both behave correctly, avoiding forced specializations and deferring to runtime.
# case 1: modulo guards

View File

@ -633,10 +633,6 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
func, (torch.randn(3, 4),)
)
@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
error_message="Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}.",
reason="https://github.com/pytorch/pytorch/issues/112622",
)
def test_operator_with_scalar_output(self):
class Foo(torch.nn.Module):
def forward(self, x, y):

View File

@ -381,6 +381,17 @@ class TestPySymInt(TestCase):
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
def test_floordiv_static(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 8)
# This was extracted from
# python test/inductor/test_cuda_cpp_wrapper.py -k
# DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
bool(s0 % 2 == 0)
bool(s0 % (s0 // 2) == 0)
bool(2 * (s0 // 2) == s0)
self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))
def test_numel(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)

View File

@ -1201,7 +1201,9 @@ def forward(self, x_1):
batch_size = 4
src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
self.assertEqual(len(gm.shape_env.guards), 0)
# Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
# 1 ok)
self.assertEqual(len(gm.shape_env.guards), 1)
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_cpu_scalar_cuda(self):

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: pt2"]
import itertools
import math
import sys
import sympy
@ -19,6 +20,7 @@ from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
import functools
import torch.fx as fx
@ -122,6 +124,74 @@ def generate_range(vals):
yield ValueRanges(a1, a2)
class TestNumbers(TestCase):
def test_int_infinity(self):
self.assertIsInstance(int_oo, IntInfinity)
self.assertIsInstance(-int_oo, NegativeIntInfinity)
self.assertTrue(int_oo.is_integer)
# is tests here are for singleton-ness, don't use it for comparisons
# against numbers
self.assertIs(int_oo + int_oo, int_oo)
self.assertIs(int_oo + 1, int_oo)
self.assertIs(int_oo - 1, int_oo)
self.assertIs(-int_oo - 1, -int_oo)
self.assertIs(-int_oo + 1, -int_oo)
self.assertIs(-int_oo + (-int_oo), -int_oo)
self.assertIs(-int_oo - int_oo, -int_oo)
self.assertIs(1 + int_oo, int_oo)
self.assertIs(1 - int_oo, -int_oo)
self.assertIs(int_oo * int_oo, int_oo)
self.assertIs(2 * int_oo, int_oo)
self.assertIs(int_oo * 2, int_oo)
self.assertIs(-1 * int_oo, -int_oo)
self.assertIs(-int_oo * int_oo, -int_oo)
self.assertIs(2 * -int_oo, -int_oo)
self.assertIs(-int_oo * 2, -int_oo)
self.assertIs(-1 * -int_oo, int_oo)
self.assertIs(int_oo / 2, sympy.oo)
self.assertIs(-(-int_oo), int_oo) # noqa: B002
self.assertIs(abs(int_oo), int_oo)
self.assertIs(abs(-int_oo), int_oo)
self.assertIs(int_oo ** 2, int_oo)
self.assertIs((-int_oo) ** 2, int_oo)
self.assertIs((-int_oo) ** 3, -int_oo)
self.assertEqual(int_oo ** -1, 0)
self.assertEqual((-int_oo) ** -1, 0)
self.assertIs(int_oo ** int_oo, int_oo)
self.assertTrue(int_oo == int_oo)
self.assertFalse(int_oo != int_oo)
self.assertTrue(-int_oo == -int_oo)
self.assertFalse(int_oo == 2)
self.assertTrue(int_oo != 2)
self.assertFalse(int_oo == sys.maxsize)
self.assertTrue(int_oo >= sys.maxsize)
self.assertTrue(int_oo >= 2)
self.assertTrue(int_oo >= -int_oo)
def test_relation(self):
self.assertIs(sympy.Add(2, int_oo), int_oo)
self.assertFalse(-int_oo > 2)
def test_lt_self(self):
self.assertFalse(int_oo < int_oo)
self.assertIs(min(-int_oo, -4), -int_oo)
self.assertIs(min(-int_oo, -int_oo), -int_oo)
def test_float_cast(self):
self.assertEqual(float(int_oo), math.inf)
self.assertEqual(float(-int_oo), -math.inf)
def test_mixed_oo_int_oo(self):
# Arbitrary choice
self.assertTrue(int_oo < sympy.oo)
self.assertFalse(int_oo > sympy.oo)
self.assertTrue(sympy.oo > int_oo)
self.assertFalse(sympy.oo < int_oo)
self.assertIs(max(int_oo, sympy.oo), sympy.oo)
self.assertTrue(-int_oo > -sympy.oo)
self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)
class TestValueRanges(TestCase):
@parametrize("fn", UNARY_OPS)
@parametrize("dtype", ("int", "float"))

View File

@ -734,6 +734,11 @@ def slice_forward(
end: Optional[int] = None,
step: int = 1,
):
from torch.fx.experimental.symbolic_shapes import (
guard_size_oblivious,
statically_known_true,
)
ndim = self.dim()
if ndim == 0:
raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
@ -760,7 +765,9 @@ def slice_forward(
if end_val < start_val:
end_val = start_val
elif end_val > sizes[dim]:
elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
end_val > sizes[dim]
):
end_val = sizes[dim]
storage_offset = self.storage_offset() + start_val * strides[dim]

View File

@ -10,6 +10,7 @@ import sympy
import torch
import torch.fx
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.numbers import int_oo
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.passes.infra.pass_base import PassBase, PassResult
@ -23,9 +24,9 @@ class InputDim(NamedTuple):
def _convert_to_int(val):
# Convert simple sympy Integers into concrete int
if val == sympy.oo:
if val in (sympy.oo, int_oo):
return math.inf
if val == -sympy.oo:
if val in (-sympy.oo, -int_oo):
return -math.inf
if isinstance(val, sympy.Integer):
return int(val)

View File

@ -42,6 +42,7 @@ from torch.fx.experimental import symbolic_shapes
from torch.utils import _pytree as pytree
from torch.utils._pytree import treespec_dumps, treespec_loads
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.numbers import int_oo
from .schema import ( # type: ignore[attr-defined]
Argument,
@ -321,9 +322,9 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...]
def _sympy_int_to_int(val: sympy.Expr, adjust: str):
# Convert simple sympy Integers into concrete int
if val == sympy.oo:
if val in (sympy.oo, int_oo):
return math.inf
if val == -sympy.oo:
if val in (-sympy.oo, -int_oo):
return -math.inf
if isinstance(val, sympy.Integer):
return int(val)
@ -346,9 +347,9 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str):
def _int_to_sympy_int(val) -> sympy.Expr:
# Convert concrete int into simple sympy Integers
if val == math.inf:
return sympy.oo
return int_oo
if val == -math.inf:
return -sympy.oo
return -int_oo
return sympy.Integer(val)
@ -1826,7 +1827,7 @@ class GraphModuleDeserializer(metaclass=Final):
self.symbol_name_to_range = {}
if symbol_name_to_range:
for k, vr in symbol_name_to_range.items():
lower = int(vr.lower)
lower = vr.lower
if vr.upper >= 2: # max is >= 2, not sym bool range
lower = max(2, lower)
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)

View File

@ -42,6 +42,7 @@ from torch.fx.experimental.symbolic_shapes import (
SymTypes,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._sympy.numbers import int_oo
from . import config, ir
from .codegen.common import (
@ -1427,18 +1428,21 @@ class GraphLowering(torch.fx.Interpreter):
vr = shape_env.var_to_range[i0]
if not shape_env._default_unspecified_value_range().issubset(vr):
def convert(s):
def is_convertible(s):
if s in (int_oo, -int_oo):
return False
try:
return int(s)
int(s)
return True
except TypeError:
return None
return False
if (lower := convert(vr.lower)) is not None:
if is_convertible(vr.lower):
self.register_buffer(
ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"),
set_name=True,
)
if (upper := convert(vr.upper)) is not None:
if is_convertible(vr.upper):
self.register_buffer(
ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"),
set_name=True,

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
import builtins
import dataclasses
import inspect
import sys
@ -41,9 +40,11 @@ class _Dim(type):
@staticmethod
def readable(name, min_, max_):
from torch.utils._sympy.numbers import int_oo
if min_ == 2:
min_ = None
if max_ == sys.maxsize - 1:
if max_ == int_oo:
max_ = None
if min_ is None and max_ is None:
return f"Dim('{name}')"
@ -140,6 +141,11 @@ class _DerivedDim(_Dim):
# TODO(avik): use sympy value range analysis instead?
from sympy import Integer
from torch.utils._sympy.numbers import int_oo
if self.root.min is -int_oo: # type: ignore[attr-defined]
return -int_oo # fn not needed cuz increasing
_min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined]
root = self.root # type: ignore[attr-defined]
assert _min_symint >= 0, (
@ -155,6 +161,11 @@ class _DerivedDim(_Dim):
# TODO(avik): use sympy value range analysis instead?
from sympy import Integer
from torch.utils._sympy.numbers import int_oo
if self.root.max is int_oo: # type: ignore[attr-defined]
return int_oo # fn not needed cuz increasing
_max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined]
root = self.root # type: ignore[attr-defined]
assert _max_symint <= sys.maxsize - 1, (
@ -190,8 +201,10 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
Returns:
A type that can be used in dynamic shape specifications for tensors.
"""
from torch.utils._sympy.numbers import int_oo
_min = 0 if min is None else min
_max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1)
_max = int_oo if max is None else max
assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
dim = _Dim(name, (int,), {"min": _min, "max": _max})
dim.__module__ = getattr(
@ -269,10 +282,11 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
def _clone_with_range(self, lower=0, upper=None):
# Import sympy locally
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import ValueRanges
if upper is None:
upper = sys.maxsize - 1
upper = int_oo
constraint_range = StrictMinMaxConstraint(
vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
@ -503,15 +517,14 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
# Import sympy locally
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import ValueRanges
return _create_constraint(
weakref.ref(t),
id(t),
index,
StrictMinMaxConstraint(
vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False
),
StrictMinMaxConstraint(vr=ValueRanges(lower=0, upper=int_oo), warn_only=False),
debug_name=debug_name,
)
@ -725,6 +738,7 @@ def _process_dynamic_shapes(
import sympy
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import ValueRanges
@ -799,7 +813,7 @@ def _process_dynamic_shapes(
constraint = dynamic_dim(tensor, i, debug_name=dim.__name__)
if dim.min != 0:
constraint = constraint >= dim.min
if dim.max != sys.maxsize - 1:
if dim.max != int_oo:
constraint = constraint <= dim.max
return constraint

View File

@ -65,6 +65,7 @@ from torch.utils._sympy.functions import (
FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt
)
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._traceback import format_frame, CapturedTraceback
@ -871,9 +872,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
for N=1.
"""
if min is None:
min = -sys.maxsize - 1
min = -int_oo
if max is None:
max = sys.maxsize - 1
max = int_oo
if max < min:
raise ValueError(
@ -1382,6 +1383,7 @@ SYMPY_INTERP = {
'PythonMod': operator.mod,
'FloorDiv': operator.floordiv,
'TrueDiv': operator.truediv,
'PowByNatural': operator.pow,
'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
'floor': math.floor,
'ceiling': math.ceil,
@ -1994,7 +1996,7 @@ class DimConstraints:
(dim.min < 2 and c.get("min", 2) == 2)
or dim.min == c.get("min", 2)
) # let pass if analysis min = 2 and specified min = 0/1
and dim.max == c.get("max", sys.maxsize - 1)
and dim.max == c.get("max", int_oo)
)
# 1) newly introduced roots
@ -2017,7 +2019,7 @@ class DimConstraints:
modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
c_min = c.get("min", 2)
min_ = math.ceil((c_min - remainder) / modulus)
c_max = c.get("max", sys.maxsize - 1)
c_max = c.get("max", int_oo)
max_ = math.floor((c_max - remainder) / modulus)
# create result & dim
results[str(root)] = {"min": min_, "max": max_}
@ -2765,7 +2767,7 @@ class ShapeEnv:
if min is None:
min = 0
if max is None:
max = sys.maxsize - 1
max = int_oo
if max < min:
raise ValueError(
@ -4094,7 +4096,7 @@ class ShapeEnv:
assert sources
bounds = []
if r.lower != -sympy.oo:
if r.lower not in (-sympy.oo, -int_oo):
if any(is_dim(source) for source in sources):
self.dim_constraints.add(sympy.Ge(symbol, r.lower))
# Only print lower bound in simplified mode if it is not the
@ -4102,14 +4104,7 @@ class ShapeEnv:
if not _simplified or r.lower != self._default_value_range().lower:
bounds.append(str(r.lower))
bounds.append(source_ref(sources[0]))
# NB: This looks like an off-by-one error but it's not: the
# upper bound may be sys.maxsize - 1 because we intentionally
# exclude sys.maxsize from our bounds to deal with direct
# == INT_MAX guards, but it's still dumb to actually test it.
# Note that you can be off by a pretty large constant and it
# won't matter because sizes in practice will be no where near
# the 64-bit limit.
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
if r.upper not in (sympy.oo, int_oo):
if any(is_dim(source) for source in sources):
self.dim_constraints.add(sympy.Le(symbol, r.upper))
# nontrivial upper bound is always interesting
@ -4121,9 +4116,8 @@ class ShapeEnv:
constraints = symbol_to_constraints[symbol]
for c in constraints:
if isinstance(c, StrictMinMaxConstraint):
# NB: By default, we have a restrictive range
# 2 <= s0 <= sys.maxsize - 1. But export users generally
# expect to be able to specify nice ranges like [0, oo]
# TODO: With int_oo, I think this condition is a noop
# now
if not (c.vr & self._default_value_range()).issubset(r):
source = sources[0]
@ -4196,9 +4190,9 @@ class ShapeEnv:
# Reason: '_maybe_evaluate_static' may eliminate guards based on the
# refined value ranges.
for sym, vr in self.var_to_range.items():
if vr.lower != -sympy.oo:
if vr.lower not in (-sympy.oo, -int_oo):
self._add_target_expr(sympy.Le(vr.lower, sym))
if vr.upper != sympy.oo:
if vr.upper not in (sympy.oo, int_oo):
self._add_target_expr(sympy.Le(sym, vr.upper))
# Before validating, populate the input of the validator with the
@ -4330,9 +4324,14 @@ class ShapeEnv:
var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols}
if size_oblivious:
# Clamp values of size-like variables
# NB: discarding the old upper bound in intentional, per
# https://github.com/pytorch/pytorch/pull/123675
for x in self.size_like & var_to_range.keys():
if var_to_range[x] is not None:
var_to_range[x] = ValueRanges(2, sys.maxsize - 1)
# NB: do NOT set upper to 2 ** 48, we're using this solely
# to determine if we can do size-like replacement, the
# upper bound is irrelevant here
var_to_range[x] = ValueRanges(2, int_oo)
assert var_to_range[x].is_int
return bound_sympy(expr, var_to_range)
@ -4450,18 +4449,25 @@ class ShapeEnv:
vr = self._default_unspecified_value_range()
if size_oblivious and k in self.size_like:
lower = max(2, vr.lower)
# Clamping size-oblivious to some quantity below sys.maxsize
# helps us determine that f(u0) != sys.maxsize, which is a
# test that is looking for sys.maxsize as a sentinel, but you
# don't really want to worry about it for unbacked SymInts.
# This is similar to the flavor where size oblivious omits
# 0/1, it changes semantics but in a benign way.
upper = min(2 ** 48, vr.upper)
# This is a bit dodgy: what this means is that there was a
# size-like unbacked symbol whose upper bound < 2. This
# causes... problems.
if lower <= vr.upper:
vr = ValueRanges(lower, vr.upper)
if lower <= upper:
vr = ValueRanges(lower, upper)
else:
lower = vr.lower
# Don't do anything if we don't have a nontrivial lower bound
# Also don't do anything if we asked only to simplify unbacked
# SymInt
if (
lower < (-sys.maxsize - 1) // 2 or
lower is -int_oo or
(unbacked_only and k in self.var_to_val) or
not vr.is_int
):
@ -4717,21 +4723,6 @@ class ShapeEnv:
if a in self.var_to_range:
src_bound = self.var_to_range[a]
# If you have x in [2, maxint], then 2*x in [4, 2*maxint].
# But we don't really care that the max bound says we can
# go beyond the maximum integer size, because we aren't
# using bigints anyway. Arguably, ValueRanges should know
# to do this truncation automaticaly (to avoid doing
# bigint compute in range analysis), but right now it doesn't
# so we need to get rid of some unnecessary precision.
int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)
def issubset(x, y):
if x.is_int and y.is_int:
return (x & int_range).issubset(y & int_range)
else:
return x.issubset(y)
# First, refine the value range of a based on the computed value range
# of tgt. This is always OK to do, even if we decide not to do the
# substitution in the end. This might be a no-op, if a already has
@ -4744,7 +4735,7 @@ class ShapeEnv:
# - the source bound non-trivially improves over what we get out of
# the existing bounds.
# - the replacement is univariate and we can invert the tgt expression
if not issubset(tgt_bound, src_bound) and len(tgt.free_symbols) == 1:
if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1:
b = next(iter(tgt.free_symbols))
# Try to invert the equality
r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
@ -4759,7 +4750,7 @@ class ShapeEnv:
b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper))
self._update_var_to_range(b, b_bound)
tgt_bound = self.bound_sympy(tgt)
assert issubset(tgt_bound, src_bound)
assert tgt_bound.issubset(src_bound)
# TODO: Should we propagate size-like-ness?
#
@ -4797,13 +4788,13 @@ class ShapeEnv:
# - If the variable is unbacked, only substitute if the substitution
# would preserve the bounds also under size-like-ness conditions.
if not issubset(tgt_bound, src_bound):
if not tgt_bound.issubset(src_bound):
self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound)
return
elif a in self.size_like:
tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
src_bound_so = self.bound_sympy(a, size_oblivious=True)
if not issubset(tgt_bound_so, src_bound_so):
if not tgt_bound_so.issubset(src_bound_so):
self.log.debug("skipped set_replacement %s = %s (%s) "
"[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so)
return
@ -4888,6 +4879,7 @@ class ShapeEnv:
has_only_ephemeral_sources = (
x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
)
# NB: size_hint is int, not sympy.Expr, do not use int_oo here
size = self.size_hint(x, allow_none=True) or sys.maxsize
name = x.name
# 1 puts ephemeral sourced symbols first when sorting in reverse
@ -4984,15 +4976,12 @@ class ShapeEnv:
return
# See: Note - On 0/1 specialization
# NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
# as a sentinel sometimes. Your sizevar isn't going to be
# anywhere near the max 64-bit integer anyway.
def _default_value_range(self) -> ValueRanges:
lower = 2 if self.specialize_zero_one else 0
return ValueRanges(lower, sys.maxsize - 1)
return ValueRanges(lower, int_oo)
def _default_unspecified_value_range(self) -> ValueRanges:
return ValueRanges(-sys.maxsize - 1, sys.maxsize)
return ValueRanges(-int_oo, int_oo)
@_lru_cache
def _simplify_floor_div(self, expr):

View File

@ -65,7 +65,7 @@ def insert_deferred_runtime_asserts(
):
assert len(node.args) == 1
nodes_that_already_have_sym_constraint_range.add(
(node.args[0], node.kwargs["min"], node.kwargs["max"])
(node.args[0], node.kwargs.get("min"), node.kwargs.get("max"))
)
if (
node.op == "call_function"
@ -86,6 +86,7 @@ def insert_deferred_runtime_asserts(
InnerTensorKey,
)
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.reference import PythonReferenceAnalysis
# TODO: Request simplification on runtime asserts before emitting them
@ -367,6 +368,8 @@ def insert_deferred_runtime_asserts(
# (refinement should not be necessary once runtime
# asserts cause refinement, but that's NYI)
def convert(s):
if s in (int_oo, -int_oo):
return None
try:
return int(s)
except TypeError:

View File

@ -6,6 +6,8 @@ import sys
import sympy
from sympy import S
from .numbers import int_oo
__all__ = [
"FloorDiv",
"ModularIndexing",
@ -101,6 +103,15 @@ class FloorDiv(sympy.Function):
# makes it difficult to check the types.
if divisor.is_zero:
raise ZeroDivisionError("division by zero")
if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in (
int_oo,
-int_oo,
sympy.oo,
-sympy.oo,
):
return sympy.nan
if base is sympy.nan or divisor is sympy.nan:
return sympy.nan
if base.is_zero:
return sympy.S.Zero
@ -108,6 +119,23 @@ class FloorDiv(sympy.Function):
return base
if base.is_integer and divisor == -1:
return sympy.Mul(base, -1)
if (
isinstance(base, sympy.Number)
and isinstance(divisor, sympy.Number)
and (
base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
)
):
r = float(base) / float(divisor)
if r == math.inf:
return int_oo
elif r == -math.inf:
return -int_oo
elif math.isnan(r):
return sympy.nan
else:
return sympy.Integer(math.floor(r))
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
return sympy.Integer(int(base) // int(divisor))
if isinstance(base, FloorDiv):
@ -353,10 +381,10 @@ class CeilToInt(sympy.Function):
@classmethod
def eval(cls, number):
# assert number.is_integer is not True, number
if number == sympy.oo:
return sympy.Integer(sys.maxsize - 1)
if number == -sympy.oo:
return sympy.Integer(-sys.maxsize - 1)
if number in (sympy.oo, int_oo):
return int_oo
if number in (-sympy.oo, -int_oo):
return -int_oo
if isinstance(number, sympy.Number):
return sympy.Integer(math.ceil(float(number)))
@ -367,10 +395,10 @@ class FloorToInt(sympy.Function):
@classmethod
def eval(cls, number):
# assert number.is_integer is not True, number
if number == sympy.oo:
return sympy.Integer(sys.maxsize - 1)
if number == -sympy.oo:
return sympy.Integer(-sys.maxsize - 1)
if number in (sympy.oo, int_oo):
return int_oo
if number in (-sympy.oo, int_oo):
return -int_oo
if isinstance(number, sympy.Number):
return sympy.Integer(math.floor(float(number)))
@ -419,6 +447,7 @@ def safe_pow(base, exp):
return sign * _safe_pow(base, exp)
# Prevent people from overflowing pow
def _safe_pow(base, exponent):
if exponent < 0:
raise ValueError("Exponent must be non-negative.")
@ -427,17 +456,20 @@ def _safe_pow(base, exponent):
return 1
half_exp = safe_pow(base, exponent // 2)
if half_exp > sys.maxsize - 1:
return sys.maxsize - 1
if half_exp is int_oo:
return int_oo
# TODO: microoptimization is to avoid overflowing into arbitrary precision
# and detect overflow prior to doing operations
result = half_exp * half_exp
if result > sys.maxsize - 1:
return sys.maxsize - 1
if result > sys.maxsize:
return int_oo
if exponent % 2 == 1:
result *= base
if result > sys.maxsize - 1:
return sys.maxsize - 1
if result > sys.maxsize:
return int_oo
return result
@ -447,14 +479,20 @@ class PowByNatural(sympy.Function):
@classmethod
def eval(cls, base, exp):
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
return sympy.Integer(safe_pow(base, exp))
if isinstance(exp, sympy.Integer):
# Translate power into iterated multiplication
r = sympy.Integer(1)
for _ in range(int(exp)):
r *= base
if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
r = safe_pow(base, exp)
if r in (-int_oo, int_oo):
return r
return sympy.Integer(r)
if isinstance(exp, sympy.Integer):
# Rely on regular sympy Pow for this (note that iterated
# multiplication turns into a Pow anyway, you can't escape!!)
return sympy.Pow(base, exp)
if exp in (int_oo, sympy.oo):
if base.is_nonnegative:
return int_oo
elif base.is_negative:
return sympy.zoo # this is apparently what (-2)**sympy.oo does
# NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
# is a natural number if we do
@ -467,6 +505,11 @@ class FloatPow(sympy.Function):
@classmethod
def eval(cls, base, exp):
# NB: These test sympy.Number, not sympy.Float, because:
# - Sometimes we may have sympy.oo or int_oo, and that's not a Float
# (but coerces to math.Inf)
# - Sometimes Float(0.0) will unpredictably decay to Integer(0),
# but we should still accept it in floatey contexts
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
return sympy.Float(float(base) ** float(exp))
# NB: do not do any nontrivial reasoning
@ -510,7 +553,18 @@ class IntTrueDiv(sympy.Function):
if divisor.is_zero:
raise ZeroDivisionError("division by zero")
if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
if (
isinstance(base, sympy.Number)
and isinstance(divisor, sympy.Number)
and (
base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
)
):
# Don't have to worry about precision here, you're getting zero or
# inf from the division
return sympy.Float(float(base) / float(divisor))
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
return sympy.Float(int(base) / int(divisor))
@ -567,10 +621,10 @@ class TruncToInt(sympy.Function):
@classmethod
def eval(cls, number):
# assert number.is_integer is not True, number
if number == sympy.oo:
return sympy.Integer(sys.maxsize - 1)
if number == -sympy.oo:
return sympy.Integer(-sys.maxsize - 1)
if number in (sympy.oo, int_oo):
return int_oo
if number in (-sympy.oo, -int_oo):
return -int_oo
if isinstance(number, sympy.Number):
return sympy.Integer(math.trunc(float(number)))
@ -583,7 +637,11 @@ class RoundToInt(sympy.Function):
def eval(cls, number):
# assert number.is_integer is not True, number
if isinstance(number, sympy.Float):
if number is sympy.oo:
return int_oo
if number is -sympy.oo:
return -int_oo
if isinstance(number, sympy.Number):
return sympy.Integer(round(float(number), 0))
@ -610,7 +668,7 @@ class RoundDecimal(sympy.Function):
def eval(cls, number, ndigits):
# assert number.is_integer is not True, number
if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer):
if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
return sympy.Float(round(float(number), int(ndigits)))
@ -625,6 +683,10 @@ class ToFloat(sympy.Function):
if isinstance(number, sympy.Integer):
return sympy.Float(int(number))
if number is int_oo:
return sympy.oo
if number is -int_oo:
return -sympy.oo
def make_opaque_unary_fn(name):
@ -655,7 +717,11 @@ def make_opaque_unary_fn(name):
# weird objects but ask silly questions, get silly answers
except OverflowError:
return getattr(sympy, name)(a)
elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo]:
elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]:
if a is int_oo:
a = sympy.oo
if a is -int_oo:
a = -sympy.oo
return getattr(sympy, name)(a)
return None

View File

@ -9,6 +9,7 @@ of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
"""
import functools
import logging
from typing import Any, Dict, Union
import sympy
@ -37,6 +38,9 @@ from .functions import (
)
log = logging.getLogger(__name__)
# TODO: Dedupe this with SYMPY_INTERP
@ -157,11 +161,18 @@ def sympy_interp(
else:
handler_name = handlers()[expr.func]
handler = getattr(analysis, handler_name)
try:
if handler_name in ASSOCIATIVE_OPS:
assert len(args) > 1
acc = handler(args[0], args[1])
for i in range(2, len(args)):
acc = handler(acc, args[i])
log.debug("%s(%s) -> %s", handler_name, args, acc)
return acc
else:
return handler(*args)
r = handler(*args)
log.debug("%s(%s) -> %s", handler_name, args, r)
return r
except Exception:
log.warning("failed while executing %s(%s)", handler_name, args)
raise

View File

@ -0,0 +1,394 @@
import mpmath.libmp as mlib # type: ignore[import-untyped]
import sympy
from sympy import Expr
from sympy.core.decorators import _sympifyit
from sympy.core.expr import AtomicExpr
from sympy.core.numbers import Number
from sympy.core.parameters import global_parameters
from sympy.core.singleton import S, Singleton
class IntInfinity(Number, metaclass=Singleton):
r"""Positive integer infinite quantity.
Integer infinity is a value in an extended integers which
is greater than all other integers. We distinguish it from
sympy's existing notion of infinity in that it reports that
it is_integer.
Infinity is a singleton, and can be accessed by ``S.IntInfinity``,
or can be imported as ``int_oo``.
"""
# NB: We can't actually mark this as infinite, as integer and infinite are
# inconsistent assumptions in sympy. We also report that we are complex,
# different from sympy.oo
is_integer = True
is_commutative = True
is_number = True
is_extended_real = True
is_comparable = True
is_extended_positive = True
is_prime = False
# Ensure we get dispatched to before plain numbers
_op_priority = 100.0
__slots__ = ()
def __new__(cls):
return AtomicExpr.__new__(cls)
def _sympystr(self, printer):
return "int_oo"
def _eval_subs(self, old, new):
if self == old:
return new
# We could do these, not sure about it
"""
def _eval_evalf(self, prec=None):
return Float('inf')
def evalf(self, prec=None, **options):
return self._eval_evalf(prec)
"""
@_sympifyit("other", NotImplemented)
def __add__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other is S.NegativeInfinity:
return S.NegativeInfinity
if other in (S.NegativeIntInfinity, S.NaN):
return S.NaN
return self
return Number.__add__(self, other)
__radd__ = __add__
@_sympifyit("other", NotImplemented)
def __sub__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other is S.Infinity:
return S.NegativeInfinity
if other in (S.IntInfinity, S.NaN):
return S.NaN
return self
return Number.__sub__(self, other)
@_sympifyit("other", NotImplemented)
def __rsub__(self, other):
return (-self).__add__(other)
@_sympifyit("other", NotImplemented)
def __mul__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other.is_zero or other is S.NaN:
return S.NaN
if other.is_extended_positive:
return self
return S.NegativeIntInfinity
return Number.__mul__(self, other)
__rmul__ = __mul__
@_sympifyit("other", NotImplemented)
def __truediv__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other in (
S.Infinity,
S.IntInfinity,
S.NegativeInfinity,
S.NegativeIntInfinity,
S.NaN,
):
return S.NaN
if other.is_extended_nonnegative:
return S.Infinity # truediv produces float
return S.NegativeInfinity # truediv produces float
return Number.__truediv__(self, other)
def __abs__(self):
return S.IntInfinity
def __neg__(self):
return S.NegativeIntInfinity
def _eval_power(self, expt):
if expt.is_extended_positive:
return S.IntInfinity
if expt.is_extended_negative:
return S.Zero
if expt is S.NaN:
return S.NaN
if expt is S.ComplexInfinity:
return S.NaN
if expt.is_extended_real is False and expt.is_number:
from sympy.functions.elementary.complexes import re
expt_real = re(expt)
if expt_real.is_positive:
return S.ComplexInfinity
if expt_real.is_negative:
return S.Zero
if expt_real.is_zero:
return S.NaN
return self ** expt.evalf()
def _as_mpf_val(self, prec):
return mlib.finf
def __hash__(self):
return super().__hash__()
def __eq__(self, other):
return other is S.IntInfinity
def __ne__(self, other):
return other is not S.IntInfinity
def __gt__(self, other):
if other is S.Infinity:
return sympy.false # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.true
def __ge__(self, other):
if other is S.Infinity:
return sympy.false # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.true
def __lt__(self, other):
if other is S.Infinity:
return sympy.true # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.false
def __le__(self, other):
if other is S.Infinity:
return sympy.true # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.false
@_sympifyit("other", NotImplemented)
def __mod__(self, other):
if not isinstance(other, Expr):
return NotImplemented
return S.NaN
__rmod__ = __mod__
def floor(self):
return self
def ceiling(self):
return self
int_oo = S.IntInfinity
class NegativeIntInfinity(Number, metaclass=Singleton):
"""Negative integer infinite quantity.
NegativeInfinity is a singleton, and can be accessed
by ``S.NegativeInfinity``.
See Also
========
IntInfinity
"""
# Ensure we get dispatched to before plain numbers
_op_priority = 100.0
is_integer = True
is_extended_real = True
is_commutative = True
is_comparable = True
is_extended_negative = True
is_number = True
is_prime = False
__slots__ = ()
def __new__(cls):
return AtomicExpr.__new__(cls)
def _eval_subs(self, old, new):
if self == old:
return new
def _sympystr(self, printer):
return "-int_oo"
"""
def _eval_evalf(self, prec=None):
return Float('-inf')
def evalf(self, prec=None, **options):
return self._eval_evalf(prec)
"""
@_sympifyit("other", NotImplemented)
def __add__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other is S.Infinity:
return S.Infinity
if other in (S.IntInfinity, S.NaN):
return S.NaN
return self
return Number.__add__(self, other)
__radd__ = __add__
@_sympifyit("other", NotImplemented)
def __sub__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other is S.NegativeInfinity:
return S.Infinity
if other in (S.NegativeIntInfinity, S.NaN):
return S.NaN
return self
return Number.__sub__(self, other)
@_sympifyit("other", NotImplemented)
def __rsub__(self, other):
return (-self).__add__(other)
@_sympifyit("other", NotImplemented)
def __mul__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other.is_zero or other is S.NaN:
return S.NaN
if other.is_extended_positive:
return self
return S.IntInfinity
return Number.__mul__(self, other)
__rmul__ = __mul__
@_sympifyit("other", NotImplemented)
def __truediv__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other in (
S.Infinity,
S.IntInfinity,
S.NegativeInfinity,
S.NegativeIntInfinity,
S.NaN,
):
return S.NaN
if other.is_extended_nonnegative:
return self
return S.Infinity # truediv returns float
return Number.__truediv__(self, other)
def __abs__(self):
return S.IntInfinity
def __neg__(self):
return S.IntInfinity
def _eval_power(self, expt):
if expt.is_number:
if expt in (
S.NaN,
S.Infinity,
S.NegativeInfinity,
S.IntInfinity,
S.NegativeIntInfinity,
):
return S.NaN
if isinstance(expt, sympy.Integer) and expt.is_extended_positive:
if expt.is_odd:
return S.NegativeIntInfinity
else:
return S.IntInfinity
inf_part = S.IntInfinity**expt
s_part = S.NegativeOne**expt
if inf_part == 0 and s_part.is_finite:
return inf_part
if (
inf_part is S.ComplexInfinity
and s_part.is_finite
and not s_part.is_zero
):
return S.ComplexInfinity
return s_part * inf_part
def _as_mpf_val(self, prec):
return mlib.fninf
def __hash__(self):
return super().__hash__()
def __eq__(self, other):
return other is S.NegativeIntInfinity
def __ne__(self, other):
return other is not S.NegativeIntInfinity
def __gt__(self, other):
if other is S.NegativeInfinity:
return sympy.true # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.false
def __ge__(self, other):
if other is S.NegativeInfinity:
return sympy.true # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.false
def __lt__(self, other):
if other is S.NegativeInfinity:
return sympy.false # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.true
def __le__(self, other):
if other is S.NegativeInfinity:
return sympy.false # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.true
@_sympifyit("other", NotImplemented)
def __mod__(self, other):
if not isinstance(other, Expr):
return NotImplemented
return S.NaN
__rmod__ = __mod__
def floor(self):
return self
def ceiling(self):
return self
def as_powers_dict(self):
return {S.NegativeOne: 1, S.IntInfinity: 1}

View File

@ -6,7 +6,6 @@ import itertools
import logging
import math
import operator
import sys
from typing import (
Callable,
Dict,
@ -24,6 +23,7 @@ import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
import torch
from torch._logging import LazyString
from torch._prims_common import dtype_to_type
from .functions import (
@ -43,6 +43,7 @@ from .functions import (
TruncToInt,
)
from .interp import sympy_interp
from .numbers import int_oo, IntInfinity, NegativeIntInfinity
log = logging.getLogger(__name__)
@ -168,7 +169,10 @@ class ValueRanges(Generic[_T]):
self,
"is_int",
not self.is_bool
and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)),
and (
isinstance(lower, (sympy.Integer, NegativeIntInfinity))
or isinstance(upper, (sympy.Integer, IntInfinity))
),
)
"""
# This assert is just impossible right now, too many sympy bugs
@ -265,11 +269,14 @@ class ValueRanges(Generic[_T]):
def is_singleton(self) -> bool:
return self.lower == self.upper
# TODO: this doesn't work with bools but arguably it should
@staticmethod
def unknown() -> ValueRanges[sympy.Expr]:
return ValueRanges(-sympy.oo, sympy.oo)
@staticmethod
def unknown_int() -> ValueRanges[sympy.Expr]:
return ValueRanges(-int_oo, int_oo)
@staticmethod
def unknown_bool() -> ValueRanges[SympyBoolean]:
return ValueRanges(sympy.false, sympy.true)
@ -401,7 +408,7 @@ class SymPyValueRangeAnalysis:
elif dtype.is_floating_point:
return ValueRanges.unknown()
else:
return ValueRanges(-sys.maxsize - 1, sys.maxsize)
return ValueRanges(-int_oo, int_oo)
if is_python:
type_ = dtype_to_type(dtype)
@ -424,6 +431,10 @@ class SymPyValueRangeAnalysis:
def to_dtype(a, dtype, src_dtype=None):
if dtype == torch.float64:
return ValueRanges.increasing_map(a, ToFloat)
elif dtype == torch.bool:
return ValueRanges.unknown_bool()
elif not dtype.is_floating_point:
return ValueRanges.unknown_int()
return ValueRanges.unknown()
@staticmethod
@ -515,9 +526,7 @@ class SymPyValueRangeAnalysis:
def int_truediv(a, b):
a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b)
if 0 in b or (
(-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
):
if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)):
return ValueRanges.unknown()
else:
return ValueRanges.coordinatewise_monotone_map(
@ -541,14 +550,17 @@ class SymPyValueRangeAnalysis:
def floordiv(a, b):
a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b)
if 0 in b or (
# TODO: make this more precise
(-sympy.oo in a or sympy.oo in a)
or (-sympy.oo in b or sympy.oo in b)
):
if 0 in b:
return ValueRanges.unknown()
products = []
for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]):
r = FloorDiv(x, y)
if r is sympy.nan:
products.append((sympy.sign(x) * sympy.sign(y)) * int_oo)
else:
return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv)
products.append(r)
return ValueRanges(min(products), max(products))
@classmethod
def mod(cls, x, y):
@ -564,10 +576,10 @@ class SymPyValueRangeAnalysis:
def c_div(a, b):
x = a / b
return sympy.Integer(x) if x.is_finite else x
return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x
if 0 in y:
return ValueRanges.unknown()
return ValueRanges.unknown_int()
elif y.is_singleton():
y_val = abs(y.lower)
# If it wraps, we need to take the whole interval
@ -597,7 +609,7 @@ class SymPyValueRangeAnalysis:
@classmethod
def is_non_overlapping_and_dense_indicator(cls, *args):
return ValueRanges.unknown() # TODO: type here is wrong
return ValueRanges.unknown_int()
@classmethod
def pow_by_natural(cls, a, b):
@ -611,7 +623,7 @@ class SymPyValueRangeAnalysis:
# to replacements, so don't assert it, but DO clamp it to prevent
# degenerate problems
return ValueRanges.coordinatewise_increasing_map(
a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural
a, b & ValueRanges(0, int_oo), PowByNatural
)
elif b.is_singleton():
if b.lower % 2 == 0:
@ -939,6 +951,8 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
if dtype.is_floating_point:
return sympy.Float(x)
else:
if x in (int_oo, -int_oo):
return x
try:
return sympy.Integer(x)
except TypeError:
@ -986,7 +1000,18 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
def bound_sympy(
expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None
) -> ValueRanges:
log.debug("bound_sympy(%s, %s)", expr, ranges)
log.debug(
"bound_sympy(%s)%s",
expr,
LazyString(
lambda: "\n"
+ "\n".join(
f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
)
if ranges
else ""
),
)
if isinstance(expr, sympy.Number):
return ValueRanges.wrap(expr)