mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Make sympify'ing SymInt/etc produce their sympy expression (#130166)
There is one huge problem this fixes: today, sympify(symint) produces a float(!!) because Sympy attempts to see if you can coerce the symint to float in sympify and of course this works on SymInt. However, this also has another nontrivial effect: anywhere in Inductor where sympy expressions are passed around, it is also valid to pass around a SymInt now. I'm ambivalent about this: it's currently a mistake to be passing around a SymInt when a sympy expression is expected. But maybe this is fine? Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/130166 Approved by: https://github.com/yf225
This commit is contained in:
parent
acd03ca2d9
commit
10c831567b
|
|
@ -260,6 +260,15 @@ class TestPySymInt(TestCase):
|
||||||
a = create_symint(shape_env, 2)
|
a = create_symint(shape_env, 2)
|
||||||
self.assertTrue(5 * a == 5 * 2)
|
self.assertTrue(5 * a == 5 * 2)
|
||||||
|
|
||||||
|
def test_sympify_symint(self):
|
||||||
|
shape_env = ShapeEnv()
|
||||||
|
a = create_symint(shape_env, 2)
|
||||||
|
self.assertIs(sympy.sympify(a), a.node.expr)
|
||||||
|
b = create_symfloat(shape_env, 3.0)
|
||||||
|
self.assertIs(sympy.sympify(b), b.node.expr)
|
||||||
|
c = create_symbool(shape_env, True)
|
||||||
|
self.assertIs(sympy.sympify(c), c.node.expr)
|
||||||
|
|
||||||
def test_roundtrip(self):
|
def test_roundtrip(self):
|
||||||
shape_env = ShapeEnv()
|
shape_env = ShapeEnv()
|
||||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
||||||
|
|
|
||||||
|
|
@ -507,6 +507,9 @@ class SymInt:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return str(self.node)
|
return str(self.node)
|
||||||
|
|
||||||
|
def _sympy_(self):
|
||||||
|
return self.node.expr
|
||||||
|
|
||||||
def __hash__(self) -> builtins.int:
|
def __hash__(self) -> builtins.int:
|
||||||
if self.node.is_nested_int():
|
if self.node.is_nested_int():
|
||||||
return hash(self.node.nested_int())
|
return hash(self.node.nested_int())
|
||||||
|
|
@ -615,6 +618,9 @@ class SymFloat:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.node.str()
|
return self.node.str()
|
||||||
|
|
||||||
|
def _sympy_(self):
|
||||||
|
return self.node.expr
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
if self.node.is_constant():
|
if self.node.is_constant():
|
||||||
return hash(self.node.float_())
|
return hash(self.node.float_())
|
||||||
|
|
@ -680,6 +686,9 @@ class SymBool:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return str(self.node)
|
return str(self.node)
|
||||||
|
|
||||||
|
def _sympy_(self):
|
||||||
|
return self.node.expr
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
if self.node.is_constant():
|
if self.node.is_constant():
|
||||||
return hash(self.node.bool_())
|
return hash(self.node.bool_())
|
||||||
|
|
|
||||||
|
|
@ -281,9 +281,7 @@ def convert_shape_to_inductor(
|
||||||
trivial. But for symbolic tensors, we need to map from SymIntNode into
|
trivial. But for symbolic tensors, we need to map from SymIntNode into
|
||||||
sympy.Expr.
|
sympy.Expr.
|
||||||
"""
|
"""
|
||||||
return [
|
return [sympy.sympify(i) for i in lst]
|
||||||
i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def convert_shape_to_symint(
|
def convert_shape_to_symint(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user