Introduce int_oo (#127693)

In a previous life, we used sympy.oo to represent the lower/upper bounds of integer ranges. Later, we changed this to be sys.maxsize - 1 for a few reasons: (1) sometimes we do tests on a value being exactly sys.maxsize, and we wanted to avoid a data dependent guard in this case, (2) sympy.oo corresponds to floating point infinity, so you get incorrect types for value ranges with oo, and (3) you can do slightly better reasoning if you assume that input sizes fall within the representable 64-bit integer range.

After working in the sys.maxsize regime for a bit, I've concluded that this was actually a bad idea. Specifically, the problem is that you end up with sys.maxsize in your upper bound, and then whenever you do any sort of size-increasing computation like size * 2, you end up with 2 * sys.maxsize, and you end up doing a ton of arbitrary precision int computation that is totally unnecessary. A symbolic bound is better.

But especially after #126905, we can't go back to using sympy.oo, because that advertises that it's not an integer, and now your ValueRanges is typed incorrectly. So what do we do? We define a new numeric constant `int_oo`, which is like `sympy.oo` but it advertises `is_integer`. **test/test_sympy_utils.py** describes some basic properties of the number, and **torch/utils/_sympy/numbers.py** has the actual implementation.

The rest of the changes of the PR are working out the implications of this change. I'll give more commentary as inline comments.

Fixes https://github.com/pytorch/pytorch/issues/127396

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127693
Approved by: https://github.com/lezcano
ghstack dependencies: #126905
This commit is contained in:
Edward Z. Yang 2024-06-09 09:48:47 -07:00 committed by PyTorch MergeBot
parent db2fa7b827
commit 9cab5987bd
19 changed files with 746 additions and 145 deletions

View File

@ -253,7 +253,6 @@ Target Expressions:
==> (>= 0 s1) ==> (>= 0 s1)
==> (>= 0 s2) ==> (>= 0 s2)
==> (>= 0 s3) ==> (>= 0 s3)
==> (>= 9223372036854775806 s0)
Failed Source Expressions: Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
@ -287,14 +286,14 @@ Failure occurred while running node:
Model: Model:
==> L['shape'][0]: 1 ==> L['shape'][0]: 1
==> L['shape'][1]: 1 ==> L['shape'][1]: 1
==> L['shape'][2]: 2 ==> L['shape'][2]: 0
==> L['x'].size()[0]: 3 ==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0 ==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1 ==> L['x'].stride()[0]: 1
==> s0: 3 ==> s0: 3
==> s1: 1 ==> s1: 1
==> s2: 1 ==> s2: 1
==> s3: 2 ==> s3: 0
Assertions: Assertions:
==> (== 0 L['x'].storage_offset()) ==> (== 0 L['x'].storage_offset())
@ -318,10 +317,6 @@ Target Expressions:
==> (== L['shape'][2] s3) ==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0) ==> (== L['x'].size()[0] s0)
==> (> s0 0) ==> (> s0 0)
==> (>= 9223372036854775806 s0)
==> (>= 9223372036854775807 s1)
==> (>= 9223372036854775807 s2)
==> (>= 9223372036854775807 s3)
Failed Source Expressions: Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",

View File

@ -3473,7 +3473,6 @@ class GraphModule(torch.nn.Module):
] ]
false_guard_code = [ false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)", "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
] ]
test_symbool_guards( test_symbool_guards(
f, f,

View File

@ -9309,7 +9309,7 @@ ShapeEnv not equal: field values don't match:
> Left: {0: 0, 1: 1, 2: s1, 3: s0} > Left: {0: 0, 1: 1, 2: s1, 3: s0}
> Right: {0: 0, 1: 1} > Right: {0: 0, 1: 1}
==> var_to_range: values don't match. ==> var_to_range: values don't match.
> Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} > Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Right: {} > Right: {}
==> var_to_sources: values don't match. ==> var_to_sources: values don't match.
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]} > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
@ -9343,7 +9343,7 @@ ShapeEnv not equal: field values don't match:
> Left: 2 > Left: 2
> Right: 0 > Right: 0
==> var_to_range: values don't match. ==> var_to_range: values don't match.
> Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]} > Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]}
> Right: {} > Right: {}
""", """,
) )
@ -9420,8 +9420,8 @@ ShapeEnv not equal: field values don't match:
> Left: {s0: 3} > Left: {s0: 3}
> Right: {} > Right: {}
==> var_to_range: values don't match. ==> var_to_range: values don't match.
> Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]} > Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
""", """,
) )
self._replay_and_check(main) self._replay_and_check(main)
@ -9458,8 +9458,8 @@ ShapeEnv not equal: field values don't match:
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match. ==> var_to_range: values don't match.
> Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]} > Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
""", """,
) )
self._replay_and_check(main) self._replay_and_check(main)

View File

@ -201,6 +201,19 @@ class TestDynamismExpression(TestCase):
dynamic_shapes={"x": {0: dim_x}}, dynamic_shapes={"x": {0: dim_x}},
) )
def test_export_slice_maxsize(self):
    # Exports aten.slice with an end index of 2**63 - 1. There are no
    # explicit assertions: the test passes iff torch.export.export does
    # not raise.
    class Slice(torch.nn.Module):
        def forward(self, *args):
            # Forward every positional argument straight to the aten op.
            return torch.ops.aten.slice.Tensor(*args)

    # Positional arguments handed to Slice.forward; the last one is the
    # large end index under test.
    example_args = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
    # Dimension 0 of the tensor argument is marked dynamic; the three
    # integer arguments get no dynamic-shape spec.
    shape_spec = (({0: Dim("dim")}, None, None, None),)
    torch.export.export(Slice(), example_args, dynamic_shapes=shape_spec)
def test_export_constraints_error(self): def test_export_constraints_error(self):
class ConflictingConstraints(torch.nn.Module): class ConflictingConstraints(torch.nn.Module):
def forward(self, x): def forward(self, x):
@ -5183,7 +5196,7 @@ def forward(self, x, y):
} }
export(f, (inputs,), dynamic_shapes=dynamic_shapes) export(f, (inputs,), dynamic_shapes=dynamic_shapes)
def test_disable_forced_specializations(self): def test_disable_forced_specializations_ok(self):
# check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags # check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags
# both behave correctly, avoiding forced specializations and deferring to runtime. # both behave correctly, avoiding forced specializations and deferring to runtime.
# case 1: modulo guards # case 1: modulo guards

View File

@ -633,10 +633,6 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
func, (torch.randn(3, 4),) func, (torch.randn(3, 4),)
) )
@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
error_message="Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}.",
reason="https://github.com/pytorch/pytorch/issues/112622",
)
def test_operator_with_scalar_output(self): def test_operator_with_scalar_output(self):
class Foo(torch.nn.Module): class Foo(torch.nn.Module):
def forward(self, x, y): def forward(self, x, y):

View File

@ -381,6 +381,17 @@ class TestPySymInt(TestCase):
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
def test_floordiv_static(self):
# Checks that, once three divisibility/equality conditions on s0 have been
# evaluated, `s0 // (s0 // 2) == 2` is statically known to be true.
shape_env = ShapeEnv()
# Symbolic integer created from the value 8 (create_symint is defined
# elsewhere in this file).
s0 = create_symint(shape_env, 8)
# This was extracted from
# python test/inductor/test_cuda_cpp_wrapper.py -k
# DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
# The results of the bool() calls below are discarded; presumably they are
# evaluated only for the side effect of recording guards on shape_env, so
# their order is kept as-is -- TODO confirm.
bool(s0 % 2 == 0)
bool(s0 % (s0 // 2) == 0)
bool(2 * (s0 // 2) == s0)
self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))
def test_numel(self): def test_numel(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env) x = create_symbolic_tensor("x", torch.randn(5), shape_env)

View File

@ -1201,7 +1201,9 @@ def forward(self, x_1):
batch_size = 4 batch_size = 4
src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size)) src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
gm = make_fx(f, tracing_mode="symbolic")(src_tokens) gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
self.assertEqual(len(gm.shape_env.guards), 0) # Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
# 1 ok)
self.assertEqual(len(gm.shape_env.guards), 1)
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test') @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_cpu_scalar_cuda(self): def test_cpu_scalar_cuda(self):

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: pt2"] # Owner(s): ["oncall: pt2"]
import itertools import itertools
import math
import sys import sys
import sympy import sympy
@ -19,6 +20,7 @@ from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
from sympy.core.relational import is_ge, is_le, is_gt, is_lt from sympy.core.relational import is_ge, is_le, is_gt, is_lt
import functools import functools
import torch.fx as fx import torch.fx as fx
@ -122,6 +124,74 @@ def generate_range(vals):
yield ValueRanges(a1, a2) yield ValueRanges(a1, a2)
class TestNumbers(TestCase):
# Unit tests for int_oo, the integer-typed infinity constant imported from
# torch.utils._sympy.numbers, and for its negation -int_oo.
def test_int_infinity(self):
# Type and integer-ness of the two constants.
self.assertIsInstance(int_oo, IntInfinity)
self.assertIsInstance(-int_oo, NegativeIntInfinity)
self.assertTrue(int_oo.is_integer)
# `is` tests here are for singleton-ness; don't use `is` for comparisons
# against numbers
# Addition / subtraction, with the infinity on either side of the operator.
self.assertIs(int_oo + int_oo, int_oo)
self.assertIs(int_oo + 1, int_oo)
self.assertIs(int_oo - 1, int_oo)
self.assertIs(-int_oo - 1, -int_oo)
self.assertIs(-int_oo + 1, -int_oo)
self.assertIs(-int_oo + (-int_oo), -int_oo)
self.assertIs(-int_oo - int_oo, -int_oo)
self.assertIs(1 + int_oo, int_oo)
self.assertIs(1 - int_oo, -int_oo)
# Multiplication: the sign of the result follows the usual sign rules.
self.assertIs(int_oo * int_oo, int_oo)
self.assertIs(2 * int_oo, int_oo)
self.assertIs(int_oo * 2, int_oo)
self.assertIs(-1 * int_oo, -int_oo)
self.assertIs(-int_oo * int_oo, -int_oo)
self.assertIs(2 * -int_oo, -int_oo)
self.assertIs(-int_oo * 2, -int_oo)
self.assertIs(-1 * -int_oo, int_oo)
# True division is expected to give sympy.oo rather than int_oo.
self.assertIs(int_oo / 2, sympy.oo)
# Negation and absolute value.
self.assertIs(-(-int_oo), int_oo)  # noqa: B002
self.assertIs(abs(int_oo), int_oo)
self.assertIs(abs(-int_oo), int_oo)
# Powers: even/odd exponents on -int_oo, negative exponents, and
# int_oo raised to itself.
self.assertIs(int_oo ** 2, int_oo)
self.assertIs((-int_oo) ** 2, int_oo)
self.assertIs((-int_oo) ** 3, -int_oo)
self.assertEqual(int_oo ** -1, 0)
self.assertEqual((-int_oo) ** -1, 0)
self.assertIs(int_oo ** int_oo, int_oo)
# Equality and ordering against itself, finite ints and sys.maxsize.
self.assertTrue(int_oo == int_oo)
self.assertFalse(int_oo != int_oo)
self.assertTrue(-int_oo == -int_oo)
self.assertFalse(int_oo == 2)
self.assertTrue(int_oo != 2)
self.assertFalse(int_oo == sys.maxsize)
self.assertTrue(int_oo >= sys.maxsize)
self.assertTrue(int_oo >= 2)
self.assertTrue(int_oo >= -int_oo)
def test_relation(self):
# sympy.Add with a finite integer is expected to return int_oo itself,
# and -int_oo is expected not to compare greater than 2.
self.assertIs(sympy.Add(2, int_oo), int_oo)
self.assertFalse(-int_oo > 2)
def test_lt_self(self):
# Strict less-than against itself is expected to be false; the builtin
# min() is expected to hand back the -int_oo singleton.
self.assertFalse(int_oo < int_oo)
self.assertIs(min(-int_oo, -4), -int_oo)
self.assertIs(min(-int_oo, -int_oo), -int_oo)
def test_float_cast(self):
# float() conversion is expected to give math.inf with the matching sign.
self.assertEqual(float(int_oo), math.inf)
self.assertEqual(float(-int_oo), -math.inf)
def test_mixed_oo_int_oo(self):
# Arbitrary choice
# The assertions below pin that choice: int_oo orders strictly below
# sympy.oo, and -int_oo strictly above -sympy.oo.
self.assertTrue(int_oo < sympy.oo)
self.assertFalse(int_oo > sympy.oo)
self.assertTrue(sympy.oo > int_oo)
self.assertFalse(sympy.oo < int_oo)
self.assertIs(max(int_oo, sympy.oo), sympy.oo)
self.assertTrue(-int_oo > -sympy.oo)
self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)
class TestValueRanges(TestCase): class TestValueRanges(TestCase):
@parametrize("fn", UNARY_OPS) @parametrize("fn", UNARY_OPS)
@parametrize("dtype", ("int", "float")) @parametrize("dtype", ("int", "float"))

View File

@ -734,6 +734,11 @@ def slice_forward(
end: Optional[int] = None, end: Optional[int] = None,
step: int = 1, step: int = 1,
): ):
from torch.fx.experimental.symbolic_shapes import (
guard_size_oblivious,
statically_known_true,
)
ndim = self.dim() ndim = self.dim()
if ndim == 0: if ndim == 0:
raise RuntimeError("slice() cannot be applied to a 0-dim tensor.") raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
@ -760,7 +765,9 @@ def slice_forward(
if end_val < start_val: if end_val < start_val:
end_val = start_val end_val = start_val
elif end_val > sizes[dim]: elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
end_val > sizes[dim]
):
end_val = sizes[dim] end_val = sizes[dim]
storage_offset = self.storage_offset() + start_val * strides[dim] storage_offset = self.storage_offset() + start_val * strides[dim]

View File

@ -10,6 +10,7 @@ import sympy
import torch import torch
import torch.fx import torch.fx
from torch.utils._sympy.value_ranges import ValueRanges from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.numbers import int_oo
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.fx.passes.infra.pass_base import PassBase, PassResult
@ -23,9 +24,9 @@ class InputDim(NamedTuple):
def _convert_to_int(val): def _convert_to_int(val):
# Convert simple sympy Integers into concrete int # Convert simple sympy Integers into concrete int
if val == sympy.oo: if val in (sympy.oo, int_oo):
return math.inf return math.inf
if val == -sympy.oo: if val in (-sympy.oo, -int_oo):
return -math.inf return -math.inf
if isinstance(val, sympy.Integer): if isinstance(val, sympy.Integer):
return int(val) return int(val)

View File

@ -42,6 +42,7 @@ from torch.fx.experimental import symbolic_shapes
from torch.utils import _pytree as pytree from torch.utils import _pytree as pytree
from torch.utils._pytree import treespec_dumps, treespec_loads from torch.utils._pytree import treespec_dumps, treespec_loads
from torch.utils._sympy.value_ranges import ValueRanges from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.numbers import int_oo
from .schema import ( # type: ignore[attr-defined] from .schema import ( # type: ignore[attr-defined]
Argument, Argument,
@ -321,9 +322,9 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...]
def _sympy_int_to_int(val: sympy.Expr, adjust: str): def _sympy_int_to_int(val: sympy.Expr, adjust: str):
# Convert simple sympy Integers into concrete int # Convert simple sympy Integers into concrete int
if val == sympy.oo: if val in (sympy.oo, int_oo):
return math.inf return math.inf
if val == -sympy.oo: if val in (-sympy.oo, -int_oo):
return -math.inf return -math.inf
if isinstance(val, sympy.Integer): if isinstance(val, sympy.Integer):
return int(val) return int(val)
@ -346,9 +347,9 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str):
def _int_to_sympy_int(val) -> sympy.Expr: def _int_to_sympy_int(val) -> sympy.Expr:
# Convert concrete int into simple sympy Integers # Convert concrete int into simple sympy Integers
if val == math.inf: if val == math.inf:
return sympy.oo return int_oo
if val == -math.inf: if val == -math.inf:
return -sympy.oo return -int_oo
return sympy.Integer(val) return sympy.Integer(val)
@ -1826,7 +1827,7 @@ class GraphModuleDeserializer(metaclass=Final):
self.symbol_name_to_range = {} self.symbol_name_to_range = {}
if symbol_name_to_range: if symbol_name_to_range:
for k, vr in symbol_name_to_range.items(): for k, vr in symbol_name_to_range.items():
lower = int(vr.lower) lower = vr.lower
if vr.upper >= 2: # max is >= 2, not sym bool range if vr.upper >= 2: # max is >= 2, not sym bool range
lower = max(2, lower) lower = max(2, lower)
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper) self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)

View File

@ -42,6 +42,7 @@ from torch.fx.experimental.symbolic_shapes import (
SymTypes, SymTypes,
) )
from torch.utils._mode_utils import no_dispatch from torch.utils._mode_utils import no_dispatch
from torch.utils._sympy.numbers import int_oo
from . import config, ir from . import config, ir
from .codegen.common import ( from .codegen.common import (
@ -1427,18 +1428,21 @@ class GraphLowering(torch.fx.Interpreter):
vr = shape_env.var_to_range[i0] vr = shape_env.var_to_range[i0]
if not shape_env._default_unspecified_value_range().issubset(vr): if not shape_env._default_unspecified_value_range().issubset(vr):
def convert(s): def is_convertible(s):
if s in (int_oo, -int_oo):
return False
try: try:
return int(s) int(s)
return True
except TypeError: except TypeError:
return None return False
if (lower := convert(vr.lower)) is not None: if is_convertible(vr.lower):
self.register_buffer( self.register_buffer(
ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"), ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"),
set_name=True, set_name=True,
) )
if (upper := convert(vr.upper)) is not None: if is_convertible(vr.upper):
self.register_buffer( self.register_buffer(
ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"), ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"),
set_name=True, set_name=True,

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import builtins
import dataclasses import dataclasses
import inspect import inspect
import sys import sys
@ -41,9 +40,11 @@ class _Dim(type):
@staticmethod @staticmethod
def readable(name, min_, max_): def readable(name, min_, max_):
from torch.utils._sympy.numbers import int_oo
if min_ == 2: if min_ == 2:
min_ = None min_ = None
if max_ == sys.maxsize - 1: if max_ == int_oo:
max_ = None max_ = None
if min_ is None and max_ is None: if min_ is None and max_ is None:
return f"Dim('{name}')" return f"Dim('{name}')"
@ -140,6 +141,11 @@ class _DerivedDim(_Dim):
# TODO(avik): use sympy value range analysis instead? # TODO(avik): use sympy value range analysis instead?
from sympy import Integer from sympy import Integer
from torch.utils._sympy.numbers import int_oo
if self.root.min is -int_oo: # type: ignore[attr-defined]
return -int_oo # fn not needed cuz increasing
_min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined] _min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined]
root = self.root # type: ignore[attr-defined] root = self.root # type: ignore[attr-defined]
assert _min_symint >= 0, ( assert _min_symint >= 0, (
@ -155,6 +161,11 @@ class _DerivedDim(_Dim):
# TODO(avik): use sympy value range analysis instead? # TODO(avik): use sympy value range analysis instead?
from sympy import Integer from sympy import Integer
from torch.utils._sympy.numbers import int_oo
if self.root.max is int_oo: # type: ignore[attr-defined]
return int_oo # fn not needed cuz increasing
_max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined] _max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined]
root = self.root # type: ignore[attr-defined] root = self.root # type: ignore[attr-defined]
assert _max_symint <= sys.maxsize - 1, ( assert _max_symint <= sys.maxsize - 1, (
@ -190,8 +201,10 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
Returns: Returns:
A type that can be used in dynamic shape specifications for tensors. A type that can be used in dynamic shape specifications for tensors.
""" """
from torch.utils._sympy.numbers import int_oo
_min = 0 if min is None else min _min = 0 if min is None else min
_max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1) _max = int_oo if max is None else max
assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}" assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
dim = _Dim(name, (int,), {"min": _min, "max": _max}) dim = _Dim(name, (int,), {"min": _min, "max": _max})
dim.__module__ = getattr( dim.__module__ = getattr(
@ -269,10 +282,11 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
def _clone_with_range(self, lower=0, upper=None): def _clone_with_range(self, lower=0, upper=None):
# Import sympy locally # Import sympy locally
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import ValueRanges from torch.utils._sympy.value_ranges import ValueRanges
if upper is None: if upper is None:
upper = sys.maxsize - 1 upper = int_oo
constraint_range = StrictMinMaxConstraint( constraint_range = StrictMinMaxConstraint(
vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
@ -503,15 +517,14 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
# Import sympy locally # Import sympy locally
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import ValueRanges from torch.utils._sympy.value_ranges import ValueRanges
return _create_constraint( return _create_constraint(
weakref.ref(t), weakref.ref(t),
id(t), id(t),
index, index,
StrictMinMaxConstraint( StrictMinMaxConstraint(vr=ValueRanges(lower=0, upper=int_oo), warn_only=False),
vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False
),
debug_name=debug_name, debug_name=debug_name,
) )
@ -725,6 +738,7 @@ def _process_dynamic_shapes(
import sympy import sympy
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.solve import try_solve from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import ValueRanges from torch.utils._sympy.value_ranges import ValueRanges
@ -799,7 +813,7 @@ def _process_dynamic_shapes(
constraint = dynamic_dim(tensor, i, debug_name=dim.__name__) constraint = dynamic_dim(tensor, i, debug_name=dim.__name__)
if dim.min != 0: if dim.min != 0:
constraint = constraint >= dim.min constraint = constraint >= dim.min
if dim.max != sys.maxsize - 1: if dim.max != int_oo:
constraint = constraint <= dim.max constraint = constraint <= dim.max
return constraint return constraint

View File

@ -65,6 +65,7 @@ from torch.utils._sympy.functions import (
FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt
) )
from torch.utils._sympy.solve import try_solve from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._traceback import format_frame, CapturedTraceback from torch.utils._traceback import format_frame, CapturedTraceback
@ -871,9 +872,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
for N=1. for N=1.
""" """
if min is None: if min is None:
min = -sys.maxsize - 1 min = -int_oo
if max is None: if max is None:
max = sys.maxsize - 1 max = int_oo
if max < min: if max < min:
raise ValueError( raise ValueError(
@ -1382,6 +1383,7 @@ SYMPY_INTERP = {
'PythonMod': operator.mod, 'PythonMod': operator.mod,
'FloorDiv': operator.floordiv, 'FloorDiv': operator.floordiv,
'TrueDiv': operator.truediv, 'TrueDiv': operator.truediv,
'PowByNatural': operator.pow,
'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
'floor': math.floor, 'floor': math.floor,
'ceiling': math.ceil, 'ceiling': math.ceil,
@ -1994,7 +1996,7 @@ class DimConstraints:
(dim.min < 2 and c.get("min", 2) == 2) (dim.min < 2 and c.get("min", 2) == 2)
or dim.min == c.get("min", 2) or dim.min == c.get("min", 2)
) # let pass if analysis min = 2 and specified min = 0/1 ) # let pass if analysis min = 2 and specified min = 0/1
and dim.max == c.get("max", sys.maxsize - 1) and dim.max == c.get("max", int_oo)
) )
# 1) newly introduced roots # 1) newly introduced roots
@ -2017,7 +2019,7 @@ class DimConstraints:
modulus, remainder = sympy.polys.polytools.div(c["eq"], root) modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
c_min = c.get("min", 2) c_min = c.get("min", 2)
min_ = math.ceil((c_min - remainder) / modulus) min_ = math.ceil((c_min - remainder) / modulus)
c_max = c.get("max", sys.maxsize - 1) c_max = c.get("max", int_oo)
max_ = math.floor((c_max - remainder) / modulus) max_ = math.floor((c_max - remainder) / modulus)
# create result & dim # create result & dim
results[str(root)] = {"min": min_, "max": max_} results[str(root)] = {"min": min_, "max": max_}
@ -2765,7 +2767,7 @@ class ShapeEnv:
if min is None: if min is None:
min = 0 min = 0
if max is None: if max is None:
max = sys.maxsize - 1 max = int_oo
if max < min: if max < min:
raise ValueError( raise ValueError(
@ -4094,7 +4096,7 @@ class ShapeEnv:
assert sources assert sources
bounds = [] bounds = []
if r.lower != -sympy.oo: if r.lower not in (-sympy.oo, -int_oo):
if any(is_dim(source) for source in sources): if any(is_dim(source) for source in sources):
self.dim_constraints.add(sympy.Ge(symbol, r.lower)) self.dim_constraints.add(sympy.Ge(symbol, r.lower))
# Only print lower bound in simplified mode if it is not the # Only print lower bound in simplified mode if it is not the
@ -4102,14 +4104,7 @@ class ShapeEnv:
if not _simplified or r.lower != self._default_value_range().lower: if not _simplified or r.lower != self._default_value_range().lower:
bounds.append(str(r.lower)) bounds.append(str(r.lower))
bounds.append(source_ref(sources[0])) bounds.append(source_ref(sources[0]))
# NB: This looks like an off-by-one error but it's not: the if r.upper not in (sympy.oo, int_oo):
# upper bound may be sys.maxsize - 1 because we intentionally
# exclude sys.maxsize from our bounds to deal with direct
# == INT_MAX guards, but it's still dumb to actually test it.
# Note that you can be off by a pretty large constant and it
# won't matter because sizes in practice will be nowhere near
# the 64-bit limit.
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
if any(is_dim(source) for source in sources): if any(is_dim(source) for source in sources):
self.dim_constraints.add(sympy.Le(symbol, r.upper)) self.dim_constraints.add(sympy.Le(symbol, r.upper))
# nontrivial upper bound is always interesting # nontrivial upper bound is always interesting
@ -4121,9 +4116,8 @@ class ShapeEnv:
constraints = symbol_to_constraints[symbol] constraints = symbol_to_constraints[symbol]
for c in constraints: for c in constraints:
if isinstance(c, StrictMinMaxConstraint): if isinstance(c, StrictMinMaxConstraint):
# NB: By default, we have a restrictive range # TODO: With int_oo, I think this condition is a noop
# 2 <= s0 <= sys.maxsize - 1. But export users generally # now
# expect to be able to specify nice ranges like [0, oo]
if not (c.vr & self._default_value_range()).issubset(r): if not (c.vr & self._default_value_range()).issubset(r):
source = sources[0] source = sources[0]
@ -4196,9 +4190,9 @@ class ShapeEnv:
# Reason: '_maybe_evaluate_static' may eliminate guards based on the # Reason: '_maybe_evaluate_static' may eliminate guards based on the
# refined value ranges. # refined value ranges.
for sym, vr in self.var_to_range.items(): for sym, vr in self.var_to_range.items():
if vr.lower != -sympy.oo: if vr.lower not in (-sympy.oo, -int_oo):
self._add_target_expr(sympy.Le(vr.lower, sym)) self._add_target_expr(sympy.Le(vr.lower, sym))
if vr.upper != sympy.oo: if vr.upper not in (sympy.oo, int_oo):
self._add_target_expr(sympy.Le(sym, vr.upper)) self._add_target_expr(sympy.Le(sym, vr.upper))
# Before validating, populate the input of the validator with the # Before validating, populate the input of the validator with the
@ -4330,9 +4324,14 @@ class ShapeEnv:
var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols}
if size_oblivious: if size_oblivious:
# Clamp values of size-like variables # Clamp values of size-like variables
# NB: discarding the old upper bound is intentional, per
# https://github.com/pytorch/pytorch/pull/123675
for x in self.size_like & var_to_range.keys(): for x in self.size_like & var_to_range.keys():
if var_to_range[x] is not None: if var_to_range[x] is not None:
var_to_range[x] = ValueRanges(2, sys.maxsize - 1) # NB: do NOT set upper to 2 ** 48, we're using this solely
# to determine if we can do size-like replacement, the
# upper bound is irrelevant here
var_to_range[x] = ValueRanges(2, int_oo)
assert var_to_range[x].is_int assert var_to_range[x].is_int
return bound_sympy(expr, var_to_range) return bound_sympy(expr, var_to_range)
@ -4450,18 +4449,25 @@ class ShapeEnv:
vr = self._default_unspecified_value_range() vr = self._default_unspecified_value_range()
if size_oblivious and k in self.size_like: if size_oblivious and k in self.size_like:
lower = max(2, vr.lower) lower = max(2, vr.lower)
# Clamping size-oblivious to some quantity below sys.maxsize
# helps us determine that f(u0) != sys.maxsize, which is a
# test that is looking for sys.maxsize as a sentinel, but you
# don't really want to worry about it for unbacked SymInts.
# This is similar to the flavor where size oblivious omits
# 0/1, it changes semantics but in a benign way.
upper = min(2 ** 48, vr.upper)
# This is a bit dodgy: what this means is that there was a # This is a bit dodgy: what this means is that there was a
# size-like unbacked symbol whose upper bound < 2. This # size-like unbacked symbol whose upper bound < 2. This
# causes... problems. # causes... problems.
if lower <= vr.upper: if lower <= upper:
vr = ValueRanges(lower, vr.upper) vr = ValueRanges(lower, upper)
else: else:
lower = vr.lower lower = vr.lower
# Don't do anything if we don't have a nontrivial lower bound # Don't do anything if we don't have a nontrivial lower bound
# Also don't do anything if we asked only to simplify unbacked # Also don't do anything if we asked only to simplify unbacked
# SymInt # SymInt
if ( if (
lower < (-sys.maxsize - 1) // 2 or lower is -int_oo or
(unbacked_only and k in self.var_to_val) or (unbacked_only and k in self.var_to_val) or
not vr.is_int not vr.is_int
): ):
@ -4717,21 +4723,6 @@ class ShapeEnv:
if a in self.var_to_range: if a in self.var_to_range:
src_bound = self.var_to_range[a] src_bound = self.var_to_range[a]
# If you have x in [2, maxint], then 2*x in [4, 2*maxint].
# But we don't really care that the max bound says we can
# go beyond the maximum integer size, because we aren't
# using bigints anyway. Arguably, ValueRanges should know
# to do this truncation automatically (to avoid doing
# bigint compute in range analysis), but right now it doesn't
# so we need to get rid of some unnecessary precision.
int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)
def issubset(x, y):
if x.is_int and y.is_int:
return (x & int_range).issubset(y & int_range)
else:
return x.issubset(y)
# First, refine the value range of a based on the computed value range # First, refine the value range of a based on the computed value range
# of tgt. This is always OK to do, even if we decide not to do the # of tgt. This is always OK to do, even if we decide not to do the
# substitution in the end. This might be a no-op, if a already has # substitution in the end. This might be a no-op, if a already has
@ -4744,7 +4735,7 @@ class ShapeEnv:
# - the source bound non-trivially improves over what we get out of # - the source bound non-trivially improves over what we get out of
# the existing bounds. # the existing bounds.
# - the replacement is univariate and we can invert the tgt expression # - the replacement is univariate and we can invert the tgt expression
if not issubset(tgt_bound, src_bound) and len(tgt.free_symbols) == 1: if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1:
b = next(iter(tgt.free_symbols)) b = next(iter(tgt.free_symbols))
# Try to invert the equality # Try to invert the equality
r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
@ -4759,7 +4750,7 @@ class ShapeEnv:
b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)) b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper))
self._update_var_to_range(b, b_bound) self._update_var_to_range(b, b_bound)
tgt_bound = self.bound_sympy(tgt) tgt_bound = self.bound_sympy(tgt)
assert issubset(tgt_bound, src_bound) assert tgt_bound.issubset(src_bound)
# TODO: Should we propagate size-like-ness? # TODO: Should we propagate size-like-ness?
# #
@ -4797,13 +4788,13 @@ class ShapeEnv:
# - If the variable is unbacked, only substitute if the substitution # - If the variable is unbacked, only substitute if the substitution
# would preserve the bounds also under size-like-ness conditions. # would preserve the bounds also under size-like-ness conditions.
if not issubset(tgt_bound, src_bound): if not tgt_bound.issubset(src_bound):
self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound) self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound)
return return
elif a in self.size_like: elif a in self.size_like:
tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
src_bound_so = self.bound_sympy(a, size_oblivious=True) src_bound_so = self.bound_sympy(a, size_oblivious=True)
if not issubset(tgt_bound_so, src_bound_so): if not tgt_bound_so.issubset(src_bound_so):
self.log.debug("skipped set_replacement %s = %s (%s) " self.log.debug("skipped set_replacement %s = %s (%s) "
"[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so)
return return
@ -4888,6 +4879,7 @@ class ShapeEnv:
has_only_ephemeral_sources = ( has_only_ephemeral_sources = (
x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x]) x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
) )
# NB: size_hint is int, not sympy.Expr, do not use int_oo here
size = self.size_hint(x, allow_none=True) or sys.maxsize size = self.size_hint(x, allow_none=True) or sys.maxsize
name = x.name name = x.name
# 1 puts ephemeral sourced symbols first when sorting in reverse # 1 puts ephemeral sourced symbols first when sorting in reverse
@ -4984,15 +4976,12 @@ class ShapeEnv:
return return
# See: Note - On 0/1 specialization # See: Note - On 0/1 specialization
# NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
# as a sentinel sometimes. Your sizevar isn't going to be
# anywhere near the max 64-bit integer anyway.
def _default_value_range(self) -> ValueRanges: def _default_value_range(self) -> ValueRanges:
lower = 2 if self.specialize_zero_one else 0 lower = 2 if self.specialize_zero_one else 0
return ValueRanges(lower, sys.maxsize - 1) return ValueRanges(lower, int_oo)
def _default_unspecified_value_range(self) -> ValueRanges: def _default_unspecified_value_range(self) -> ValueRanges:
return ValueRanges(-sys.maxsize - 1, sys.maxsize) return ValueRanges(-int_oo, int_oo)
@_lru_cache @_lru_cache
def _simplify_floor_div(self, expr): def _simplify_floor_div(self, expr):

View File

@ -65,7 +65,7 @@ def insert_deferred_runtime_asserts(
): ):
assert len(node.args) == 1 assert len(node.args) == 1
nodes_that_already_have_sym_constraint_range.add( nodes_that_already_have_sym_constraint_range.add(
(node.args[0], node.kwargs["min"], node.kwargs["max"]) (node.args[0], node.kwargs.get("min"), node.kwargs.get("max"))
) )
if ( if (
node.op == "call_function" node.op == "call_function"
@ -86,6 +86,7 @@ def insert_deferred_runtime_asserts(
InnerTensorKey, InnerTensorKey,
) )
from torch.utils._sympy.interp import sympy_interp from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.reference import PythonReferenceAnalysis from torch.utils._sympy.reference import PythonReferenceAnalysis
# TODO: Request simplification on runtime asserts before emitting them # TODO: Request simplification on runtime asserts before emitting them
@ -367,6 +368,8 @@ def insert_deferred_runtime_asserts(
# (refinement should not be necessary once runtime # (refinement should not be necessary once runtime
# asserts cause refinement, but that's NYI) # asserts cause refinement, but that's NYI)
def convert(s): def convert(s):
if s in (int_oo, -int_oo):
return None
try: try:
return int(s) return int(s)
except TypeError: except TypeError:

View File

@ -6,6 +6,8 @@ import sys
import sympy import sympy
from sympy import S from sympy import S
from .numbers import int_oo
__all__ = [ __all__ = [
"FloorDiv", "FloorDiv",
"ModularIndexing", "ModularIndexing",
@ -101,6 +103,15 @@ class FloorDiv(sympy.Function):
# makes it difficult to check the types. # makes it difficult to check the types.
if divisor.is_zero: if divisor.is_zero:
raise ZeroDivisionError("division by zero") raise ZeroDivisionError("division by zero")
if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in (
int_oo,
-int_oo,
sympy.oo,
-sympy.oo,
):
return sympy.nan
if base is sympy.nan or divisor is sympy.nan:
return sympy.nan
if base.is_zero: if base.is_zero:
return sympy.S.Zero return sympy.S.Zero
@ -108,6 +119,23 @@ class FloorDiv(sympy.Function):
return base return base
if base.is_integer and divisor == -1: if base.is_integer and divisor == -1:
return sympy.Mul(base, -1) return sympy.Mul(base, -1)
if (
isinstance(base, sympy.Number)
and isinstance(divisor, sympy.Number)
and (
base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
)
):
r = float(base) / float(divisor)
if r == math.inf:
return int_oo
elif r == -math.inf:
return -int_oo
elif math.isnan(r):
return sympy.nan
else:
return sympy.Integer(math.floor(r))
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
return sympy.Integer(int(base) // int(divisor)) return sympy.Integer(int(base) // int(divisor))
if isinstance(base, FloorDiv): if isinstance(base, FloorDiv):
@ -353,10 +381,10 @@ class CeilToInt(sympy.Function):
@classmethod @classmethod
def eval(cls, number): def eval(cls, number):
# assert number.is_integer is not True, number # assert number.is_integer is not True, number
if number == sympy.oo: if number in (sympy.oo, int_oo):
return sympy.Integer(sys.maxsize - 1) return int_oo
if number == -sympy.oo: if number in (-sympy.oo, -int_oo):
return sympy.Integer(-sys.maxsize - 1) return -int_oo
if isinstance(number, sympy.Number): if isinstance(number, sympy.Number):
return sympy.Integer(math.ceil(float(number))) return sympy.Integer(math.ceil(float(number)))
@ -367,10 +395,10 @@ class FloorToInt(sympy.Function):
@classmethod
def eval(cls, number):
    """Evaluate floor-to-integer for concrete numeric arguments.

    Infinite inputs (float ``sympy.oo`` or integer ``int_oo``, of either
    sign) map to the integer infinity of the same sign.  Any other
    ``sympy.Number`` is floored via Python floats.  For anything else the
    method falls off the end and returns ``None``, which sympy treats as
    "leave the application unevaluated".
    """
    # assert number.is_integer is not True, number
    if number in (sympy.oo, int_oo):
        return int_oo
    # NB: the negative case must test -int_oo (not int_oo, which was
    # already handled above); otherwise -int_oo would reach the generic
    # Number branch below and be pushed through float().
    if number in (-sympy.oo, -int_oo):
        return -int_oo
    if isinstance(number, sympy.Number):
        return sympy.Integer(math.floor(float(number)))
@ -419,6 +447,7 @@ def safe_pow(base, exp):
return sign * _safe_pow(base, exp) return sign * _safe_pow(base, exp)
# Prevent people from overflowing pow
def _safe_pow(base, exponent): def _safe_pow(base, exponent):
if exponent < 0: if exponent < 0:
raise ValueError("Exponent must be non-negative.") raise ValueError("Exponent must be non-negative.")
@ -427,17 +456,20 @@ def _safe_pow(base, exponent):
return 1 return 1
half_exp = safe_pow(base, exponent // 2) half_exp = safe_pow(base, exponent // 2)
if half_exp > sys.maxsize - 1: if half_exp is int_oo:
return sys.maxsize - 1 return int_oo
# TODO: microoptimization is to avoid overflowing into arbitrary precision
# and detect overflow prior to doing operations
result = half_exp * half_exp result = half_exp * half_exp
if result > sys.maxsize - 1: if result > sys.maxsize:
return sys.maxsize - 1 return int_oo
if exponent % 2 == 1: if exponent % 2 == 1:
result *= base result *= base
if result > sys.maxsize - 1: if result > sys.maxsize:
return sys.maxsize - 1 return int_oo
return result return result
@ -447,14 +479,20 @@ class PowByNatural(sympy.Function):
@classmethod @classmethod
def eval(cls, base, exp): def eval(cls, base, exp):
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
return sympy.Integer(safe_pow(base, exp)) r = safe_pow(base, exp)
if r in (-int_oo, int_oo):
return r
return sympy.Integer(r)
if isinstance(exp, sympy.Integer): if isinstance(exp, sympy.Integer):
# Translate power into iterated multiplication # Rely on regular sympy Pow for this (note that iterated
r = sympy.Integer(1) # multiplication turns into a Pow anyway, you can't escape!!)
for _ in range(int(exp)): return sympy.Pow(base, exp)
r *= base if exp in (int_oo, sympy.oo):
return r if base.is_nonnegative:
return int_oo
elif base.is_negative:
return sympy.zoo # this is apparently what (-2)**sympy.oo does
# NB: do NOT translate into sympy.Pow, we will lose knowledge that exp # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
# is a natural number if we do # is a natural number if we do
@ -467,6 +505,11 @@ class FloatPow(sympy.Function):
@classmethod @classmethod
def eval(cls, base, exp): def eval(cls, base, exp):
# NB: These test sympy.Number, not sympy.Float, because:
# - Sometimes we may have sympy.oo or int_oo, and that's not a Float
# (but coerces to math.Inf)
# - Sometimes Float(0.0) will unpredictably decay to Integer(0),
# but we should still accept it in floatey contexts
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
return sympy.Float(float(base) ** float(exp)) return sympy.Float(float(base) ** float(exp))
# NB: do not do any nontrivial reasoning # NB: do not do any nontrivial reasoning
@ -510,7 +553,18 @@ class IntTrueDiv(sympy.Function):
if divisor.is_zero: if divisor.is_zero:
raise ZeroDivisionError("division by zero") raise ZeroDivisionError("division by zero")
if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): if (
isinstance(base, sympy.Number)
and isinstance(divisor, sympy.Number)
and (
base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
)
):
# Don't have to worry about precision here, you're getting zero or
# inf from the division
return sympy.Float(float(base) / float(divisor))
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
return sympy.Float(int(base) / int(divisor)) return sympy.Float(int(base) / int(divisor))
@ -567,10 +621,10 @@ class TruncToInt(sympy.Function):
@classmethod @classmethod
def eval(cls, number): def eval(cls, number):
# assert number.is_integer is not True, number # assert number.is_integer is not True, number
if number == sympy.oo: if number in (sympy.oo, int_oo):
return sympy.Integer(sys.maxsize - 1) return int_oo
if number == -sympy.oo: if number in (-sympy.oo, -int_oo):
return sympy.Integer(-sys.maxsize - 1) return -int_oo
if isinstance(number, sympy.Number): if isinstance(number, sympy.Number):
return sympy.Integer(math.trunc(float(number))) return sympy.Integer(math.trunc(float(number)))
@ -583,7 +637,11 @@ class RoundToInt(sympy.Function):
def eval(cls, number): def eval(cls, number):
# assert number.is_integer is not True, number # assert number.is_integer is not True, number
if isinstance(number, sympy.Float): if number is sympy.oo:
return int_oo
if number is -sympy.oo:
return -int_oo
if isinstance(number, sympy.Number):
return sympy.Integer(round(float(number), 0)) return sympy.Integer(round(float(number), 0))
@ -610,7 +668,7 @@ class RoundDecimal(sympy.Function):
def eval(cls, number, ndigits): def eval(cls, number, ndigits):
# assert number.is_integer is not True, number # assert number.is_integer is not True, number
if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer): if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
return sympy.Float(round(float(number), int(ndigits))) return sympy.Float(round(float(number), int(ndigits)))
@ -625,6 +683,10 @@ class ToFloat(sympy.Function):
if isinstance(number, sympy.Integer): if isinstance(number, sympy.Integer):
return sympy.Float(int(number)) return sympy.Float(int(number))
if number is int_oo:
return sympy.oo
if number is -int_oo:
return -sympy.oo
def make_opaque_unary_fn(name): def make_opaque_unary_fn(name):
@ -655,7 +717,11 @@ def make_opaque_unary_fn(name):
# weird objects but ask silly questions, get silly answers # weird objects but ask silly questions, get silly answers
except OverflowError: except OverflowError:
return getattr(sympy, name)(a) return getattr(sympy, name)(a)
elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo]: elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]:
if a is int_oo:
a = sympy.oo
if a is -int_oo:
a = -sympy.oo
return getattr(sympy, name)(a) return getattr(sympy, name)(a)
return None return None

View File

@ -9,6 +9,7 @@ of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
""" """
import functools import functools
import logging
from typing import Any, Dict, Union from typing import Any, Dict, Union
import sympy import sympy
@ -37,6 +38,9 @@ from .functions import (
) )
log = logging.getLogger(__name__)
# TODO: Dedupe this with SYMPY_INTERP # TODO: Dedupe this with SYMPY_INTERP
@ -157,11 +161,18 @@ def sympy_interp(
else: else:
handler_name = handlers()[expr.func] handler_name = handlers()[expr.func]
handler = getattr(analysis, handler_name) handler = getattr(analysis, handler_name)
if handler_name in ASSOCIATIVE_OPS: try:
assert len(args) > 1 if handler_name in ASSOCIATIVE_OPS:
acc = handler(args[0], args[1]) assert len(args) > 1
for i in range(2, len(args)): acc = handler(args[0], args[1])
acc = handler(acc, args[i]) for i in range(2, len(args)):
return acc acc = handler(acc, args[i])
else: log.debug("%s(%s) -> %s", handler_name, args, acc)
return handler(*args) return acc
else:
r = handler(*args)
log.debug("%s(%s) -> %s", handler_name, args, r)
return r
except Exception:
log.warning("failed while executing %s(%s)", handler_name, args)
raise

View File

@ -0,0 +1,394 @@
import mpmath.libmp as mlib # type: ignore[import-untyped]
import sympy
from sympy import Expr
from sympy.core.decorators import _sympifyit
from sympy.core.expr import AtomicExpr
from sympy.core.numbers import Number
from sympy.core.parameters import global_parameters
from sympy.core.singleton import S, Singleton
class IntInfinity(Number, metaclass=Singleton):
    r"""Positive infinity that advertises itself as an integer.

    This value lives in the extended integers and compares greater than
    every ordinary integer.  Unlike sympy's built-in infinity it reports
    ``is_integer``; sympy's own ``oo`` is treated as strictly larger than
    this value by the comparison methods below.

    The class is a singleton, reachable as ``S.IntInfinity`` or through the
    module-level alias ``int_oo``.
    """

    # NB: We can't actually mark this as infinite, as integer and infinite are
    # inconsistent assumptions in sympy.  We also report that we are complex,
    # different from sympy.oo
    is_integer = True
    is_commutative = True
    is_number = True
    is_extended_real = True
    is_comparable = True
    is_extended_positive = True
    is_prime = False

    # Ensure we get dispatched to before plain numbers
    _op_priority = 100.0

    __slots__ = ()

    def __new__(cls):
        return AtomicExpr.__new__(cls)

    def _sympystr(self, printer):
        return "int_oo"

    def _eval_subs(self, old, new):
        # Only a substitution of this exact value is handled here.
        if self == old:
            return new

    # We could do these, not sure about it
    #
    #   def _eval_evalf(self, prec=None):
    #       return Float('inf')
    #
    #   def evalf(self, prec=None, **options):
    #       return self._eval_evalf(prec)

    @_sympifyit("other", NotImplemented)
    def __add__(self, other):
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__add__(self, other)
        # Float -oo dominates int_oo.
        if other is S.NegativeInfinity:
            return S.NegativeInfinity
        # int_oo + (-int_oo) is indeterminate, and NaN propagates.
        if other in (S.NegativeIntInfinity, S.NaN):
            return S.NaN
        return self

    __radd__ = __add__

    @_sympifyit("other", NotImplemented)
    def __sub__(self, other):
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__sub__(self, other)
        # Float oo dominates int_oo.
        if other is S.Infinity:
            return S.NegativeInfinity
        # int_oo - int_oo is indeterminate, and NaN propagates.
        if other in (S.IntInfinity, S.NaN):
            return S.NaN
        return self

    @_sympifyit("other", NotImplemented)
    def __rsub__(self, other):
        # other - self == (-self) + other
        return (-self).__add__(other)

    @_sympifyit("other", NotImplemented)
    def __mul__(self, other):
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__mul__(self, other)
        if other.is_zero or other is S.NaN:
            return S.NaN
        return self if other.is_extended_positive else S.NegativeIntInfinity

    __rmul__ = __mul__

    @_sympifyit("other", NotImplemented)
    def __truediv__(self, other):
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__truediv__(self, other)
        indeterminate_divisors = (
            S.Infinity,
            S.IntInfinity,
            S.NegativeInfinity,
            S.NegativeIntInfinity,
            S.NaN,
        )
        if other in indeterminate_divisors:
            return S.NaN
        # truediv produces float, hence the float infinities here
        return S.Infinity if other.is_extended_nonnegative else S.NegativeInfinity

    def __abs__(self):
        return S.IntInfinity

    def __neg__(self):
        return S.NegativeIntInfinity

    def _eval_power(self, expt):
        if expt.is_extended_positive:
            return S.IntInfinity
        if expt.is_extended_negative:
            return S.Zero
        if expt is S.NaN or expt is S.ComplexInfinity:
            return S.NaN
        if expt.is_extended_real is False and expt.is_number:
            from sympy.functions.elementary.complexes import re

            real_part = re(expt)
            if real_part.is_positive:
                return S.ComplexInfinity
            if real_part.is_negative:
                return S.Zero
            if real_part.is_zero:
                return S.NaN

            return self ** expt.evalf()

    def _as_mpf_val(self, prec):
        return mlib.finf

    def __hash__(self):
        # Defining __eq__ would otherwise drop the inherited hash.
        return super().__hash__()

    def __eq__(self, other):
        return other is S.IntInfinity

    def __ne__(self, other):
        return other is not S.IntInfinity

    # Ordering: sympy.oo > int_oo > everything else; comparisons against
    # int_oo itself follow the same conventions as sympy.oo.

    def __gt__(self, other):
        if other is S.Infinity or other is S.IntInfinity:
            return sympy.false
        return sympy.true

    def __ge__(self, other):
        return sympy.false if other is S.Infinity else sympy.true

    def __lt__(self, other):
        return sympy.true if other is S.Infinity else sympy.false

    def __le__(self, other):
        if other is S.Infinity or other is S.IntInfinity:
            return sympy.true
        return sympy.false

    @_sympifyit("other", NotImplemented)
    def __mod__(self, other):
        return S.NaN if isinstance(other, Expr) else NotImplemented

    __rmod__ = __mod__

    def floor(self):
        return self

    def ceiling(self):
        return self
int_oo = S.IntInfinity
class NegativeIntInfinity(Number, metaclass=Singleton):
    """Negative integer infinite quantity.

    NegativeIntInfinity is a singleton, and can be accessed
    by ``S.NegativeIntInfinity``.

    See Also
    ========

    IntInfinity
    """

    # Ensure we get dispatched to before plain numbers
    _op_priority = 100.0

    is_integer = True
    is_extended_real = True
    is_commutative = True
    is_comparable = True
    is_extended_negative = True
    is_number = True
    is_prime = False

    __slots__ = ()

    def __new__(cls):
        return AtomicExpr.__new__(cls)

    def _eval_subs(self, old, new):
        # Only a substitution of this exact value is handled here.
        if self == old:
            return new

    def _sympystr(self, printer):
        return "-int_oo"

    # NOTE: _eval_evalf / evalf overrides returning Float('-inf') were
    # sketched here but are left disabled.

    @_sympifyit("other", NotImplemented)
    def __add__(self, other):
        if isinstance(other, Number) and global_parameters.evaluate:
            # Float oo dominates -int_oo.
            if other is S.Infinity:
                return S.Infinity
            # -int_oo + int_oo is indeterminate, and NaN propagates.
            if other in (S.IntInfinity, S.NaN):
                return S.NaN
            return self
        return Number.__add__(self, other)

    __radd__ = __add__

    @_sympifyit("other", NotImplemented)
    def __sub__(self, other):
        if isinstance(other, Number) and global_parameters.evaluate:
            # Subtracting float -oo dominates -int_oo.
            if other is S.NegativeInfinity:
                return S.Infinity
            # -int_oo - (-int_oo) is indeterminate, and NaN propagates.
            if other in (S.NegativeIntInfinity, S.NaN):
                return S.NaN
            return self
        return Number.__sub__(self, other)

    @_sympifyit("other", NotImplemented)
    def __rsub__(self, other):
        # other - self == (-self) + other
        return (-self).__add__(other)

    @_sympifyit("other", NotImplemented)
    def __mul__(self, other):
        if isinstance(other, Number) and global_parameters.evaluate:
            if other.is_zero or other is S.NaN:
                return S.NaN
            if other.is_extended_positive:
                return self
            return S.IntInfinity
        return Number.__mul__(self, other)

    __rmul__ = __mul__

    @_sympifyit("other", NotImplemented)
    def __truediv__(self, other):
        if isinstance(other, Number) and global_parameters.evaluate:
            if other in (
                S.Infinity,
                S.IntInfinity,
                S.NegativeInfinity,
                S.NegativeIntInfinity,
                S.NaN,
            ):
                return S.NaN
            # truediv returns float, so both results are the *float*
            # infinities (mirroring IntInfinity.__truediv__), never the
            # integer-typed ones.
            if other.is_extended_nonnegative:
                return S.NegativeInfinity
            return S.Infinity
        return Number.__truediv__(self, other)

    def __abs__(self):
        return S.IntInfinity

    def __neg__(self):
        return S.IntInfinity

    def _eval_power(self, expt):
        if expt.is_number:
            if expt in (
                S.NaN,
                S.Infinity,
                S.NegativeInfinity,
                S.IntInfinity,
                S.NegativeIntInfinity,
            ):
                return S.NaN

            # Positive integer exponents keep the result integer-typed;
            # the sign follows the parity of the exponent.
            if isinstance(expt, sympy.Integer) and expt.is_extended_positive:
                if expt.is_odd:
                    return S.NegativeIntInfinity
                else:
                    return S.IntInfinity

            # Otherwise split (-int_oo)**expt into (-1)**expt * int_oo**expt.
            inf_part = S.IntInfinity**expt
            s_part = S.NegativeOne**expt
            if inf_part == 0 and s_part.is_finite:
                return inf_part
            if (
                inf_part is S.ComplexInfinity
                and s_part.is_finite
                and not s_part.is_zero
            ):
                return S.ComplexInfinity
            return s_part * inf_part

    def _as_mpf_val(self, prec):
        return mlib.fninf

    def __hash__(self):
        # Defining __eq__ would otherwise drop the inherited hash.
        return super().__hash__()

    def __eq__(self, other):
        return other is S.NegativeIntInfinity

    def __ne__(self, other):
        return other is not S.NegativeIntInfinity

    def __gt__(self, other):
        if other is S.NegativeInfinity:
            return sympy.true  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.false  # consistency with sympy.oo
        else:
            return sympy.false

    def __ge__(self, other):
        if other is S.NegativeInfinity:
            return sympy.true  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.true  # consistency with sympy.oo
        else:
            return sympy.false

    def __lt__(self, other):
        if other is S.NegativeInfinity:
            return sympy.false  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.false  # consistency with sympy.oo
        else:
            return sympy.true

    def __le__(self, other):
        if other is S.NegativeInfinity:
            return sympy.false  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.true  # consistency with sympy.oo
        else:
            return sympy.true

    @_sympifyit("other", NotImplemented)
    def __mod__(self, other):
        if not isinstance(other, Expr):
            return NotImplemented
        return S.NaN

    __rmod__ = __mod__

    def floor(self):
        return self

    def ceiling(self):
        return self

    def as_powers_dict(self):
        # Represent -int_oo as (-1)**1 * int_oo**1.
        return {S.NegativeOne: 1, S.IntInfinity: 1}

View File

@ -6,7 +6,6 @@ import itertools
import logging import logging
import math import math
import operator import operator
import sys
from typing import ( from typing import (
Callable, Callable,
Dict, Dict,
@ -24,6 +23,7 @@ import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
import torch import torch
from torch._logging import LazyString
from torch._prims_common import dtype_to_type from torch._prims_common import dtype_to_type
from .functions import ( from .functions import (
@ -43,6 +43,7 @@ from .functions import (
TruncToInt, TruncToInt,
) )
from .interp import sympy_interp from .interp import sympy_interp
from .numbers import int_oo, IntInfinity, NegativeIntInfinity
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -168,7 +169,10 @@ class ValueRanges(Generic[_T]):
self, self,
"is_int", "is_int",
not self.is_bool not self.is_bool
and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)), and (
isinstance(lower, (sympy.Integer, NegativeIntInfinity))
or isinstance(upper, (sympy.Integer, IntInfinity))
),
) )
""" """
# This assert is just impossible right now, too many sympy bugs # This assert is just impossible right now, too many sympy bugs
@ -265,11 +269,14 @@ class ValueRanges(Generic[_T]):
def is_singleton(self) -> bool: def is_singleton(self) -> bool:
return self.lower == self.upper return self.lower == self.upper
# TODO: this doesn't work with bools but arguably it should
@staticmethod @staticmethod
def unknown() -> ValueRanges[sympy.Expr]: def unknown() -> ValueRanges[sympy.Expr]:
return ValueRanges(-sympy.oo, sympy.oo) return ValueRanges(-sympy.oo, sympy.oo)
@staticmethod
def unknown_int() -> ValueRanges[sympy.Expr]:
return ValueRanges(-int_oo, int_oo)
@staticmethod @staticmethod
def unknown_bool() -> ValueRanges[SympyBoolean]: def unknown_bool() -> ValueRanges[SympyBoolean]:
return ValueRanges(sympy.false, sympy.true) return ValueRanges(sympy.false, sympy.true)
@ -401,7 +408,7 @@ class SymPyValueRangeAnalysis:
elif dtype.is_floating_point: elif dtype.is_floating_point:
return ValueRanges.unknown() return ValueRanges.unknown()
else: else:
return ValueRanges(-sys.maxsize - 1, sys.maxsize) return ValueRanges(-int_oo, int_oo)
if is_python: if is_python:
type_ = dtype_to_type(dtype) type_ = dtype_to_type(dtype)
@ -424,6 +431,10 @@ class SymPyValueRangeAnalysis:
def to_dtype(a, dtype, src_dtype=None): def to_dtype(a, dtype, src_dtype=None):
if dtype == torch.float64: if dtype == torch.float64:
return ValueRanges.increasing_map(a, ToFloat) return ValueRanges.increasing_map(a, ToFloat)
elif dtype == torch.bool:
return ValueRanges.unknown_bool()
elif not dtype.is_floating_point:
return ValueRanges.unknown_int()
return ValueRanges.unknown() return ValueRanges.unknown()
@staticmethod @staticmethod
@ -515,9 +526,7 @@ class SymPyValueRangeAnalysis:
def int_truediv(a, b): def int_truediv(a, b):
a = ValueRanges.wrap(a) a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b) b = ValueRanges.wrap(b)
if 0 in b or ( if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)):
(-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
):
return ValueRanges.unknown() return ValueRanges.unknown()
else: else:
return ValueRanges.coordinatewise_monotone_map( return ValueRanges.coordinatewise_monotone_map(
@ -541,14 +550,17 @@ class SymPyValueRangeAnalysis:
def floordiv(a, b): def floordiv(a, b):
a = ValueRanges.wrap(a) a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b) b = ValueRanges.wrap(b)
if 0 in b or ( if 0 in b:
# TODO: make this more precise
(-sympy.oo in a or sympy.oo in a)
or (-sympy.oo in b or sympy.oo in b)
):
return ValueRanges.unknown() return ValueRanges.unknown()
else: products = []
return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv) for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]):
r = FloorDiv(x, y)
if r is sympy.nan:
products.append((sympy.sign(x) * sympy.sign(y)) * int_oo)
else:
products.append(r)
return ValueRanges(min(products), max(products))
@classmethod @classmethod
def mod(cls, x, y): def mod(cls, x, y):
@ -564,10 +576,10 @@ class SymPyValueRangeAnalysis:
def c_div(a, b): def c_div(a, b):
x = a / b x = a / b
return sympy.Integer(x) if x.is_finite else x return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x
if 0 in y: if 0 in y:
return ValueRanges.unknown() return ValueRanges.unknown_int()
elif y.is_singleton(): elif y.is_singleton():
y_val = abs(y.lower) y_val = abs(y.lower)
# If it wraps, we need to take the whole interval # If it wraps, we need to take the whole interval
@ -597,7 +609,7 @@ class SymPyValueRangeAnalysis:
@classmethod @classmethod
def is_non_overlapping_and_dense_indicator(cls, *args): def is_non_overlapping_and_dense_indicator(cls, *args):
return ValueRanges.unknown() # TODO: type here is wrong return ValueRanges.unknown_int()
@classmethod @classmethod
def pow_by_natural(cls, a, b): def pow_by_natural(cls, a, b):
@ -611,7 +623,7 @@ class SymPyValueRangeAnalysis:
# to replacements, so don't assert it, but DO clamp it to prevent # to replacements, so don't assert it, but DO clamp it to prevent
# degenerate problems # degenerate problems
return ValueRanges.coordinatewise_increasing_map( return ValueRanges.coordinatewise_increasing_map(
a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural a, b & ValueRanges(0, int_oo), PowByNatural
) )
elif b.is_singleton(): elif b.is_singleton():
if b.lower % 2 == 0: if b.lower % 2 == 0:
@ -939,6 +951,8 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
if dtype.is_floating_point: if dtype.is_floating_point:
return sympy.Float(x) return sympy.Float(x)
else: else:
if x in (int_oo, -int_oo):
return x
try: try:
return sympy.Integer(x) return sympy.Integer(x)
except TypeError: except TypeError:
@ -986,7 +1000,18 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
def bound_sympy( def bound_sympy(
expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None
) -> ValueRanges: ) -> ValueRanges:
log.debug("bound_sympy(%s, %s)", expr, ranges) log.debug(
"bound_sympy(%s)%s",
expr,
LazyString(
lambda: "\n"
+ "\n".join(
f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
)
if ranges
else ""
),
)
if isinstance(expr, sympy.Number): if isinstance(expr, sympy.Number):
return ValueRanges.wrap(expr) return ValueRanges.wrap(expr)