pytorch/test/inductor/test_utils.py
Xuehai Pan 7763c83af6 [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
2024-05-27 04:22:18 +00:00

57 lines
1.9 KiB
Python

# Owner(s): ["module: inductor"]
from sympy import Symbol
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import sympy_subs
class TestUtils(TestCase):
    def testSympySubs(self):
        """Exercise ``sympy_subs`` replacement rules and assumption handling.

        The cases below check that:

        * a string replacement becomes a symbol that carries the ``integer``
          and ``nonnegative`` assumptions of the replaced expression;
        * a ``Symbol`` key only replaces a symbol whose assumptions match
          (e.g. ``Symbol("x")`` does not replace ``Symbol("x", integer=True)``);
        * a string key is rejected with an ``AssertionError``;
        * a non-symbol expression such as ``abs(x)`` can be replaced, and a
          string replacement for it carries that expression's assumptions.
        """
        # Integer and nonnegative attributes are preserved for string
        # replacements; an assumption-free symbol yields None for both.
        expr = Symbol("x")
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

        expr = Symbol("x", integer=True, nonnegative=False)
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, True)
        self.assertEqual(result.is_nonnegative, False)

        # Invalid replacement: the key has no assumptions, so it does not
        # match the integer symbol and ``expr`` is returned unchanged.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x"): Symbol("y")})
        self.assertEqual(result.name, "x")

        # Valid replacement since the key's assumptions match.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")})
        self.assertEqual(result.name, "y")

        # Invalid replacement: ``integer=None`` on ``expr`` does not match
        # the key's ``integer=False``.
        expr = Symbol("x", integer=None)
        result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")})
        self.assertEqual(result.name, "x")

        # The replaced key can't be a string.
        with self.assertRaises(AssertionError):
            sympy_subs(expr, {"x": "y"})

        # The replaced key can be a non-symbol expression.
        expr = Symbol("x")
        expr = abs(expr)
        self.assertEqual(expr.is_integer, None)
        self.assertEqual(expr.is_nonnegative, None)

        # Replace abs(x) with an explicit Symbol.
        result = sympy_subs(expr, {expr: Symbol("y")})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

        # Replace abs(x) with a string: this is the path that propagates
        # abs(x)'s sympy properties (both unknown for an unconstrained x).
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

        # With an integer argument, abs(x) is a nonnegative integer, so the
        # propagated assumptions are non-trivial and must both carry over.
        expr = abs(Symbol("x", integer=True))
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, True)
        self.assertEqual(result.is_nonnegative, True)
if __name__ == "__main__":
    # Only run the suite when executed as a script, via the run_tests helper
    # imported from torch._inductor.test_case; importing this module is inert.
    run_tests()