Revert "Add symbolic singleton int (#110370)"

This reverts commit a7145cb3a4.

Reverted https://github.com/pytorch/pytorch/pull/110370 on behalf of https://github.com/PaliC due to bottom diff is causing a plethora of internal failures ([comment](https://github.com/pytorch/pytorch/pull/110370#issuecomment-1749801188))
This commit is contained in:
PyTorch MergeBot 2023-10-05 23:47:09 +00:00
parent 585e2bd818
commit fdf6055ea7
2 changed files with 0 additions and 181 deletions

View File

@ -18,10 +18,6 @@ from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.reference import ReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.singleton_int import SingletonInt
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
import functools
UNARY_OPS = [
@ -524,89 +520,6 @@ class TestSympySolve(TestCase):
r = solver.check()
self.assertEqual(r, z3.unsat)
class TestSingletonInt(TestCase):
def test_basic(self):
j1 = SingletonInt(1, coeff=1)
j1_copy = SingletonInt(1, coeff=1)
j2 = SingletonInt(2, coeff=1)
j1x2 = SingletonInt(1, coeff=2)
def test_eq(a, b, expected):
self.assertEqual(sympy.Eq(a, b), expected)
self.assertEqual(sympy.Ne(b, a), not expected)
# eq, ne
test_eq(j1, j1, True)
test_eq(j1, j1_copy, True)
test_eq(j1, j2, False)
test_eq(j1, j1x2, False)
test_eq(j1, sympy.Integer(1), False)
test_eq(j1, sympy.Integer(3), False)
def test_ineq(a, b, expected, *, strict=True):
greater = (sympy.Gt, is_gt) if strict else (sympy.Ge, is_ge)
less = (sympy.Lt, is_lt) if strict else (sympy.Le, is_le)
if isinstance(expected, bool):
# expected is always True
for fn in greater:
self.assertEqual(fn(a, b), expected)
self.assertEqual(fn(b, a), not expected)
for fn in less:
self.assertEqual(fn(b, a), expected)
self.assertEqual(fn(a, b), not expected)
else:
for fn in greater:
with self.assertRaisesRegex(ValueError, expected):
fn(a, b)
for fn in less:
with self.assertRaisesRegex(ValueError, expected):
fn(b, a)
# ge, le, gt, lt
for strict in (True, False):
_test_ineq = functools.partial(test_ineq, strict=strict)
_test_ineq(j1, sympy.Integer(0), True)
_test_ineq(j1, sympy.Integer(3), "indeterminate")
_test_ineq(j1, j2, "indeterminate")
_test_ineq(j1x2, j1, True)
# Special cases for ge, le, gt, lt:
for ge in (sympy.Ge, is_ge):
self.assertTrue(ge(j1, j1))
self.assertTrue(ge(j1, sympy.Integer(2)))
with self.assertRaisesRegex(ValueError, "indeterminate"):
ge(sympy.Integer(2), j1)
for le in (sympy.Le, is_le):
self.assertTrue(le(j1, j1))
self.assertTrue(le(sympy.Integer(2), j1))
with self.assertRaisesRegex(ValueError, "indeterminate"):
le(j1, sympy.Integer(2))
for gt in (sympy.Gt, is_gt):
self.assertFalse(gt(j1, j1))
self.assertFalse(gt(sympy.Integer(2), j1))
# it is only known to be that j1 >= 2, j1 > 2 is indeterminate
with self.assertRaisesRegex(ValueError, "indeterminate"):
gt(j1, sympy.Integer(2))
for lt in (sympy.Lt, is_lt):
self.assertFalse(lt(j1, j1))
self.assertFalse(lt(j1, sympy.Integer(2)))
with self.assertRaisesRegex(ValueError, "indeterminate"):
lt(sympy.Integer(2), j1)
# mul
self.assertEqual(j1 * 2, j1x2)
# Unfortunately, this doesn't not automatically simplify to 2*j1
# since sympy.Mul doesn't trigger __mul__ unlike the above.
self.assertIsInstance(sympy.Mul(j1, 2), sympy.core.mul.Mul)
with self.assertRaisesRegex(ValueError, "cannot be multiplied"):
j1 * j2
self.assertEqual(j1.free_symbols, set())
instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)

View File

@ -1,94 +0,0 @@
import sympy
from sympy.multipledispatch import dispatch
__all__ = ["SingletonInt"]
class SingletonInt(sympy.AtomicExpr):
# This is probably not super important unless we are in multiple dispatch
# situations with other more exotic Expr types.
_op_priority = 99999
def __new__(cls, *args, coeff=None, **kwargs):
instance = super().__new__(cls, *args, **kwargs)
return instance
# The semantics of this class should match that of SingletonSymNodeImpl in
# c10/core/SingletonSymNodeImpl.h
def __init__(self, val, *, coeff=1):
self._val = val
self._coeff = coeff
super().__init__()
# See NOTE [ Inequalities with SingletonInt ]
def _eval_Eq(self, other):
if (
isinstance(other, SingletonInt)
and other._val == self._val
and self._coeff == other._coeff
):
return sympy.true
else:
return sympy.false
# This is necessary so that calling expr.free_symbols on exprs that contain
# this Singleton does not error
@property
def free_symbols(self):
return set()
def __mul__(self, other):
if isinstance(other, SingletonInt):
raise ValueError(
"SingletonInt cannot be multiplied by another SingletonInt"
)
return SingletonInt(self._val, coeff=self._coeff * other)
def __rmul__(self, other):
if isinstance(other, SingletonInt):
raise ValueError(
"SingletonInt cannot be multiplied by another SingletonInt"
)
return SingletonInt(self._val, coeff=self._coeff * other)
# Make sure we promptly raise an error instead of falling back to building
# an expression tree. There are probably more ops, how can we be exhaustive?
def __add__(self, other):
raise NotImplementedError("NYI")
def __sub__(self, other):
raise NotImplementedError("NYI")
def __truediv__(self, other):
raise NotImplementedError("NYI")
def __floordiv__(self, other):
raise NotImplementedError("NYI")
def __mod__(self, other):
raise NotImplementedError("NYI")
# See NOTE [ Inequalities with SingletonInt ]
@dispatch(sympy.Integer, SingletonInt)
def _eval_is_ge(a, b):
if a < 2:
return sympy.false
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
@dispatch(SingletonInt, sympy.Integer) # type: ignore[no-redef]
def _eval_is_ge(a, b): # noqa: F811
if b <= 2:
return sympy.true
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
@dispatch(SingletonInt, SingletonInt) # type: ignore[no-redef]
def _eval_is_ge(a, b): # noqa: F811
if a._val == b._val:
if a._coeff >= b._coeff:
return sympy.true
else:
return sympy.false
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")