mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This reverts commit 7763c83af6.
Reverted https://github.com/pytorch/pytorch/pull/127126 on behalf of https://github.com/XuehaiPan due to Broken CI ([comment](https://github.com/pytorch/pytorch/pull/127126#issuecomment-2133044286))
56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
# Owner(s): ["module: inductor"]
|
|
|
|
from sympy import Symbol
|
|
from torch._inductor.test_case import run_tests, TestCase
|
|
from torch._inductor.utils import sympy_subs
|
|
|
|
|
|
class TestUtils(TestCase):
    """Tests for helpers in ``torch._inductor.utils``."""

    def testSympySubs(self):
        """Check the replacement semantics of ``sympy_subs``.

        Covers string replacements, keys whose assumptions do or do not match
        the expression being replaced, rejection of string keys, and replacing
        a non-Symbol expression.
        """
        # String replacement: the result is a Symbol named after the string
        # that keeps the integer / nonnegative assumptions of the replaced
        # symbol (here the original has none, so both stay None).
        expr = Symbol("x")
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

        # Same as above, but the explicit assumptions must carry over.
        expr = Symbol("x", integer=True, nonnegative=False)
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, True)
        self.assertEqual(result.is_nonnegative, False)

        # Non-matching key: Symbol("x") lacks the integer assumption, so it is
        # a different symbol from expr and nothing is replaced.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x"): Symbol("y")})
        self.assertEqual(result.name, "x")

        # Matching key: same name and same assumptions, so the replacement
        # takes effect.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")})
        self.assertEqual(result.name, "y")

        # Non-matching key: expr leaves `integer` unspecified while the key
        # has integer=False, so nothing is replaced.
        expr = Symbol("x", integer=None)
        result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")})
        self.assertEqual(result.name, "x")

        # Keys of the replacement mapping can't be strings.
        with self.assertRaises(AssertionError):
            sympy_subs(expr, {"x": "y"})

        # The key being replaced can be an arbitrary expression, not only a
        # Symbol.
        expr = Symbol("x")
        expr = abs(expr)
        self.assertEqual(expr.is_integer, None)
        self.assertEqual(expr.is_nonnegative, None)
        # Replace abs(x) with y; the result's assumptions should match those
        # of abs(x).
        # NOTE(review): abs(x) reports None for both assumptions here, which
        # is also what a fresh Symbol("y") reports, so these asserts cannot
        # tell propagation apart from no propagation.
        result = sympy_subs(expr, {expr: Symbol("y")})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script (not imported): hand control to the
    # inductor test harness so the cases above are discovered and run.
    run_tests()
|