diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 474d70db690..72bcd35e0a5 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -18,10 +18,6 @@ from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges from torch.utils._sympy.reference import ReferenceAnalysis from torch.utils._sympy.interp import sympy_interp -from torch.utils._sympy.singleton_int import SingletonInt -from sympy.core.relational import is_ge, is_le, is_gt, is_lt -import functools - UNARY_OPS = [ @@ -524,89 +520,6 @@ class TestSympySolve(TestCase): r = solver.check() self.assertEqual(r, z3.unsat) -class TestSingletonInt(TestCase): - def test_basic(self): - j1 = SingletonInt(1, coeff=1) - j1_copy = SingletonInt(1, coeff=1) - j2 = SingletonInt(2, coeff=1) - j1x2 = SingletonInt(1, coeff=2) - - def test_eq(a, b, expected): - self.assertEqual(sympy.Eq(a, b), expected) - self.assertEqual(sympy.Ne(b, a), not expected) - - # eq, ne - test_eq(j1, j1, True) - test_eq(j1, j1_copy, True) - test_eq(j1, j2, False) - test_eq(j1, j1x2, False) - test_eq(j1, sympy.Integer(1), False) - test_eq(j1, sympy.Integer(3), False) - - def test_ineq(a, b, expected, *, strict=True): - greater = (sympy.Gt, is_gt) if strict else (sympy.Ge, is_ge) - less = (sympy.Lt, is_lt) if strict else (sympy.Le, is_le) - - if isinstance(expected, bool): - # expected is always True - for fn in greater: - self.assertEqual(fn(a, b), expected) - self.assertEqual(fn(b, a), not expected) - for fn in less: - self.assertEqual(fn(b, a), expected) - self.assertEqual(fn(a, b), not expected) - else: - for fn in greater: - with self.assertRaisesRegex(ValueError, expected): - fn(a, b) - for fn in less: - with self.assertRaisesRegex(ValueError, expected): - fn(b, a) - - # ge, le, gt, lt - for strict in (True, False): - _test_ineq = functools.partial(test_ineq, strict=strict) - _test_ineq(j1, sympy.Integer(0), True) - _test_ineq(j1, sympy.Integer(3), "indeterminate") - _test_ineq(j1, j2, "indeterminate") - _test_ineq(j1x2, j1, True) - - # Special cases for ge, le, gt, lt: - for ge in (sympy.Ge, is_ge): - self.assertTrue(ge(j1, j1)) - self.assertTrue(ge(j1, sympy.Integer(2))) - with self.assertRaisesRegex(ValueError, "indeterminate"): - ge(sympy.Integer(2), j1) - for le in (sympy.Le, is_le): - self.assertTrue(le(j1, j1)) - self.assertTrue(le(sympy.Integer(2), j1)) - with self.assertRaisesRegex(ValueError, "indeterminate"): - le(j1, sympy.Integer(2)) - - for gt in (sympy.Gt, is_gt): - self.assertFalse(gt(j1, j1)) - self.assertFalse(gt(sympy.Integer(2), j1)) - # it is only known to be that j1 >= 2, j1 > 2 is indeterminate - with self.assertRaisesRegex(ValueError, "indeterminate"): - gt(j1, sympy.Integer(2)) - - for lt in (sympy.Lt, is_lt): - self.assertFalse(lt(j1, j1)) - self.assertFalse(lt(j1, sympy.Integer(2))) - with self.assertRaisesRegex(ValueError, "indeterminate"): - lt(sympy.Integer(2), j1) - - # mul - self.assertEqual(j1 * 2, j1x2) - # Unfortunately, this doesn't not automatically simplify to 2*j1 - # since sympy.Mul doesn't trigger __mul__ unlike the above. - self.assertIsInstance(sympy.Mul(j1, 2), sympy.core.mul.Mul) - - with self.assertRaisesRegex(ValueError, "cannot be multiplied"): - j1 * j2 - - self.assertEqual(j1.free_symbols, set()) - instantiate_parametrized_tests(TestValueRanges) instantiate_parametrized_tests(TestSympyInterp) diff --git a/torch/utils/_sympy/singleton_int.py b/torch/utils/_sympy/singleton_int.py deleted file mode 100644 index d67e3732a27..00000000000 --- a/torch/utils/_sympy/singleton_int.py +++ /dev/null @@ -1,94 +0,0 @@ -import sympy -from sympy.multipledispatch import dispatch - -__all__ = ["SingletonInt"] - - -class SingletonInt(sympy.AtomicExpr): - # This is probably not super important unless we are in multiple dispatch - # situations with other more exotic Expr types. - _op_priority = 99999 - - def __new__(cls, *args, coeff=None, **kwargs): - instance = super().__new__(cls, *args, **kwargs) - return instance - - # The semantics of this class should match that of SingletonSymNodeImpl in - # c10/core/SingletonSymNodeImpl.h - def __init__(self, val, *, coeff=1): - self._val = val - self._coeff = coeff - super().__init__() - - # See NOTE [ Inequalities with SingletonInt ] - def _eval_Eq(self, other): - if ( - isinstance(other, SingletonInt) - and other._val == self._val - and self._coeff == other._coeff - ): - return sympy.true - else: - return sympy.false - - # This is necessary so that calling expr.free_symbols on exprs that contain - # this Singleton does not error - @property - def free_symbols(self): - return set() - - def __mul__(self, other): - if isinstance(other, SingletonInt): - raise ValueError( - "SingletonInt cannot be multiplied by another SingletonInt" - ) - return SingletonInt(self._val, coeff=self._coeff * other) - - def __rmul__(self, other): - if isinstance(other, SingletonInt): - raise ValueError( - "SingletonInt cannot be multiplied by another SingletonInt" - ) - return SingletonInt(self._val, coeff=self._coeff * other) - - # Make sure we promptly raise an error instead of falling back to building - # an expression tree. There are probably more ops, how can we be exhaustive? - def __add__(self, other): - raise NotImplementedError("NYI") - - def __sub__(self, other): - raise NotImplementedError("NYI") - - def __truediv__(self, other): - raise NotImplementedError("NYI") - - def __floordiv__(self, other): - raise NotImplementedError("NYI") - - def __mod__(self, other): - raise NotImplementedError("NYI") - - -# See NOTE [ Inequalities with SingletonInt ] -@dispatch(sympy.Integer, SingletonInt) -def _eval_is_ge(a, b): - if a < 2: - return sympy.false - raise ValueError("Symbolic SingletonInt: Relation is indeterminate") - - -@dispatch(SingletonInt, sympy.Integer) # type: ignore[no-redef] -def _eval_is_ge(a, b): # noqa: F811 - if b <= 2: - return sympy.true - raise ValueError("Symbolic SingletonInt: Relation is indeterminate") - - -@dispatch(SingletonInt, SingletonInt) # type: ignore[no-redef] -def _eval_is_ge(a, b): # noqa: F811 - if a._val == b._val: - if a._coeff >= b._coeff: - return sympy.true - else: - return sympy.false - raise ValueError("Symbolic SingletonInt: Relation is indeterminate")