mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert "Add symbolic singleton int (#110370)"
This reverts commit a7145cb3a4.
Reverted https://github.com/pytorch/pytorch/pull/110370 on behalf of https://github.com/PaliC due to bottom diff is causing a plethora of internal failures ([comment](https://github.com/pytorch/pytorch/pull/110370#issuecomment-1749801188))
This commit is contained in:
parent
585e2bd818
commit
fdf6055ea7
|
|
@ -18,10 +18,6 @@ from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
||||||
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
||||||
from torch.utils._sympy.reference import ReferenceAnalysis
|
from torch.utils._sympy.reference import ReferenceAnalysis
|
||||||
from torch.utils._sympy.interp import sympy_interp
|
from torch.utils._sympy.interp import sympy_interp
|
||||||
from torch.utils._sympy.singleton_int import SingletonInt
|
|
||||||
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
|
|
||||||
import functools
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
UNARY_OPS = [
|
UNARY_OPS = [
|
||||||
|
|
@ -524,89 +520,6 @@ class TestSympySolve(TestCase):
|
||||||
r = solver.check()
|
r = solver.check()
|
||||||
self.assertEqual(r, z3.unsat)
|
self.assertEqual(r, z3.unsat)
|
||||||
|
|
||||||
class TestSingletonInt(TestCase):
|
|
||||||
def test_basic(self):
|
|
||||||
j1 = SingletonInt(1, coeff=1)
|
|
||||||
j1_copy = SingletonInt(1, coeff=1)
|
|
||||||
j2 = SingletonInt(2, coeff=1)
|
|
||||||
j1x2 = SingletonInt(1, coeff=2)
|
|
||||||
|
|
||||||
def test_eq(a, b, expected):
|
|
||||||
self.assertEqual(sympy.Eq(a, b), expected)
|
|
||||||
self.assertEqual(sympy.Ne(b, a), not expected)
|
|
||||||
|
|
||||||
# eq, ne
|
|
||||||
test_eq(j1, j1, True)
|
|
||||||
test_eq(j1, j1_copy, True)
|
|
||||||
test_eq(j1, j2, False)
|
|
||||||
test_eq(j1, j1x2, False)
|
|
||||||
test_eq(j1, sympy.Integer(1), False)
|
|
||||||
test_eq(j1, sympy.Integer(3), False)
|
|
||||||
|
|
||||||
def test_ineq(a, b, expected, *, strict=True):
|
|
||||||
greater = (sympy.Gt, is_gt) if strict else (sympy.Ge, is_ge)
|
|
||||||
less = (sympy.Lt, is_lt) if strict else (sympy.Le, is_le)
|
|
||||||
|
|
||||||
if isinstance(expected, bool):
|
|
||||||
# expected is always True
|
|
||||||
for fn in greater:
|
|
||||||
self.assertEqual(fn(a, b), expected)
|
|
||||||
self.assertEqual(fn(b, a), not expected)
|
|
||||||
for fn in less:
|
|
||||||
self.assertEqual(fn(b, a), expected)
|
|
||||||
self.assertEqual(fn(a, b), not expected)
|
|
||||||
else:
|
|
||||||
for fn in greater:
|
|
||||||
with self.assertRaisesRegex(ValueError, expected):
|
|
||||||
fn(a, b)
|
|
||||||
for fn in less:
|
|
||||||
with self.assertRaisesRegex(ValueError, expected):
|
|
||||||
fn(b, a)
|
|
||||||
|
|
||||||
# ge, le, gt, lt
|
|
||||||
for strict in (True, False):
|
|
||||||
_test_ineq = functools.partial(test_ineq, strict=strict)
|
|
||||||
_test_ineq(j1, sympy.Integer(0), True)
|
|
||||||
_test_ineq(j1, sympy.Integer(3), "indeterminate")
|
|
||||||
_test_ineq(j1, j2, "indeterminate")
|
|
||||||
_test_ineq(j1x2, j1, True)
|
|
||||||
|
|
||||||
# Special cases for ge, le, gt, lt:
|
|
||||||
for ge in (sympy.Ge, is_ge):
|
|
||||||
self.assertTrue(ge(j1, j1))
|
|
||||||
self.assertTrue(ge(j1, sympy.Integer(2)))
|
|
||||||
with self.assertRaisesRegex(ValueError, "indeterminate"):
|
|
||||||
ge(sympy.Integer(2), j1)
|
|
||||||
for le in (sympy.Le, is_le):
|
|
||||||
self.assertTrue(le(j1, j1))
|
|
||||||
self.assertTrue(le(sympy.Integer(2), j1))
|
|
||||||
with self.assertRaisesRegex(ValueError, "indeterminate"):
|
|
||||||
le(j1, sympy.Integer(2))
|
|
||||||
|
|
||||||
for gt in (sympy.Gt, is_gt):
|
|
||||||
self.assertFalse(gt(j1, j1))
|
|
||||||
self.assertFalse(gt(sympy.Integer(2), j1))
|
|
||||||
# it is only known to be that j1 >= 2, j1 > 2 is indeterminate
|
|
||||||
with self.assertRaisesRegex(ValueError, "indeterminate"):
|
|
||||||
gt(j1, sympy.Integer(2))
|
|
||||||
|
|
||||||
for lt in (sympy.Lt, is_lt):
|
|
||||||
self.assertFalse(lt(j1, j1))
|
|
||||||
self.assertFalse(lt(j1, sympy.Integer(2)))
|
|
||||||
with self.assertRaisesRegex(ValueError, "indeterminate"):
|
|
||||||
lt(sympy.Integer(2), j1)
|
|
||||||
|
|
||||||
# mul
|
|
||||||
self.assertEqual(j1 * 2, j1x2)
|
|
||||||
# Unfortunately, this doesn't not automatically simplify to 2*j1
|
|
||||||
# since sympy.Mul doesn't trigger __mul__ unlike the above.
|
|
||||||
self.assertIsInstance(sympy.Mul(j1, 2), sympy.core.mul.Mul)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, "cannot be multiplied"):
|
|
||||||
j1 * j2
|
|
||||||
|
|
||||||
self.assertEqual(j1.free_symbols, set())
|
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(TestValueRanges)
|
instantiate_parametrized_tests(TestValueRanges)
|
||||||
instantiate_parametrized_tests(TestSympyInterp)
|
instantiate_parametrized_tests(TestSympyInterp)
|
||||||
|
|
|
||||||
|
|
@ -1,94 +0,0 @@
|
||||||
import sympy
|
|
||||||
from sympy.multipledispatch import dispatch
|
|
||||||
|
|
||||||
__all__ = ["SingletonInt"]
|
|
||||||
|
|
||||||
|
|
||||||
class SingletonInt(sympy.AtomicExpr):
|
|
||||||
# This is probably not super important unless we are in multiple dispatch
|
|
||||||
# situations with other more exotic Expr types.
|
|
||||||
_op_priority = 99999
|
|
||||||
|
|
||||||
def __new__(cls, *args, coeff=None, **kwargs):
|
|
||||||
instance = super().__new__(cls, *args, **kwargs)
|
|
||||||
return instance
|
|
||||||
|
|
||||||
# The semantics of this class should match that of SingletonSymNodeImpl in
|
|
||||||
# c10/core/SingletonSymNodeImpl.h
|
|
||||||
def __init__(self, val, *, coeff=1):
|
|
||||||
self._val = val
|
|
||||||
self._coeff = coeff
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# See NOTE [ Inequalities with SingletonInt ]
|
|
||||||
def _eval_Eq(self, other):
|
|
||||||
if (
|
|
||||||
isinstance(other, SingletonInt)
|
|
||||||
and other._val == self._val
|
|
||||||
and self._coeff == other._coeff
|
|
||||||
):
|
|
||||||
return sympy.true
|
|
||||||
else:
|
|
||||||
return sympy.false
|
|
||||||
|
|
||||||
# This is necessary so that calling expr.free_symbols on exprs that contain
|
|
||||||
# this Singleton does not error
|
|
||||||
@property
|
|
||||||
def free_symbols(self):
|
|
||||||
return set()
|
|
||||||
|
|
||||||
def __mul__(self, other):
|
|
||||||
if isinstance(other, SingletonInt):
|
|
||||||
raise ValueError(
|
|
||||||
"SingletonInt cannot be multiplied by another SingletonInt"
|
|
||||||
)
|
|
||||||
return SingletonInt(self._val, coeff=self._coeff * other)
|
|
||||||
|
|
||||||
def __rmul__(self, other):
|
|
||||||
if isinstance(other, SingletonInt):
|
|
||||||
raise ValueError(
|
|
||||||
"SingletonInt cannot be multiplied by another SingletonInt"
|
|
||||||
)
|
|
||||||
return SingletonInt(self._val, coeff=self._coeff * other)
|
|
||||||
|
|
||||||
# Make sure we promptly raise an error instead of falling back to building
|
|
||||||
# an expression tree. There are probably more ops, how can we be exhaustive?
|
|
||||||
def __add__(self, other):
|
|
||||||
raise NotImplementedError("NYI")
|
|
||||||
|
|
||||||
def __sub__(self, other):
|
|
||||||
raise NotImplementedError("NYI")
|
|
||||||
|
|
||||||
def __truediv__(self, other):
|
|
||||||
raise NotImplementedError("NYI")
|
|
||||||
|
|
||||||
def __floordiv__(self, other):
|
|
||||||
raise NotImplementedError("NYI")
|
|
||||||
|
|
||||||
def __mod__(self, other):
|
|
||||||
raise NotImplementedError("NYI")
|
|
||||||
|
|
||||||
|
|
||||||
# See NOTE [ Inequalities with SingletonInt ]
|
|
||||||
@dispatch(sympy.Integer, SingletonInt)
|
|
||||||
def _eval_is_ge(a, b):
|
|
||||||
if a < 2:
|
|
||||||
return sympy.false
|
|
||||||
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
|
|
||||||
|
|
||||||
|
|
||||||
@dispatch(SingletonInt, sympy.Integer) # type: ignore[no-redef]
|
|
||||||
def _eval_is_ge(a, b): # noqa: F811
|
|
||||||
if b <= 2:
|
|
||||||
return sympy.true
|
|
||||||
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
|
|
||||||
|
|
||||||
|
|
||||||
@dispatch(SingletonInt, SingletonInt) # type: ignore[no-redef]
|
|
||||||
def _eval_is_ge(a, b): # noqa: F811
|
|
||||||
if a._val == b._val:
|
|
||||||
if a._coeff >= b._coeff:
|
|
||||||
return sympy.true
|
|
||||||
else:
|
|
||||||
return sympy.false
|
|
||||||
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
|
|
||||||
Loading…
Reference in New Issue
Block a user