From 10c831567b36717c4155b47c72db13f7b924815b Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 5 Jul 2024 13:35:15 -0700 Subject: [PATCH] Make sympify'ing SymInt/etc produce their sympy expression (#130166) There is one huge problem this fixes: today, sympify(symint) produces a float(!!) because Sympy attempts to see if you can coerce the symint to float in sympify and of course this works on SymInt. However, this also has another nontrivial effect: anywhere in Inductor where sympy expressions are passed around, it is also valid to pass around a SymInt now. I'm ambivalent about this: it's currently a mistake to be passing around a SymInt when a sympy expression is expected. But maybe this is fine? Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/130166 Approved by: https://github.com/yf225 --- test/test_dynamic_shapes.py | 9 +++++++++ torch/__init__.py | 9 +++++++++ torch/_inductor/utils.py | 4 +--- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 2cef46bc2fe..e87937bc7d6 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -260,6 +260,15 @@ class TestPySymInt(TestCase): a = create_symint(shape_env, 2) self.assertTrue(5 * a == 5 * 2) + def test_sympify_symint(self): + shape_env = ShapeEnv() + a = create_symint(shape_env, 2) + self.assertIs(sympy.sympify(a), a.node.expr) + b = create_symfloat(shape_env, 3.0) + self.assertIs(sympy.sympify(b), b.node.expr) + c = create_symbool(shape_env, True) + self.assertIs(sympy.sympify(c), c.node.expr) + def test_roundtrip(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) diff --git a/torch/__init__.py b/torch/__init__.py index a6c731cd2c1..9ad49b253ff 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -507,6 +507,9 @@ class SymInt: def __repr__(self): return str(self.node) + def _sympy_(self): + return self.node.expr + def __hash__(self) -> builtins.int: if self.node.is_nested_int(): return hash(self.node.nested_int()) @@ -615,6 +618,9 @@ class SymFloat: def __repr__(self): return self.node.str() + def _sympy_(self): + return self.node.expr + def __hash__(self): if self.node.is_constant(): return hash(self.node.float_()) @@ -680,6 +686,9 @@ class SymBool: def __repr__(self): return str(self.node) + def _sympy_(self): + return self.node.expr + def __hash__(self): if self.node.is_constant(): return hash(self.node.bool_()) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 2d93114cf71..7c7d63a4bc8 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -281,9 +281,7 @@ def convert_shape_to_inductor( trivial. But for symbolic tensors, we need to map from SymIntNode into sympy.Expr. """ - return [ - i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst - ] + return [sympy.sympify(i) for i in lst] def convert_shape_to_symint(