Add __int__ and __float__ methods to _sympy.functions.Identity (#155873)

Fixes #155688

Root Cause:
In [`torch/_inductor/index_propagation.py`](f151b20123/torch/_inductor/index_propagation.py (L57-L68)), when a `TypedExpr` is created from an `Identity` (a `torch.utils._sympy.functions.Identity`, not a `sympy.matrices.expressions.Identity`) and the inner value of the identity, `Identity.args[0]`, is any torch int type, the `TypedExpr.__post_init__` method tries to cast the Identity object to a Python `int`. This is where the `TypeError` from the issue was raised, because `Identity` did not know how to cast itself to an `int`.
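
A minimal sketch of the failing call (using only names that appear in the diff below; the concrete value `1` is illustrative):

```python
import torch
from torch._inductor.index_propagation import TypedExpr
from torch.utils._sympy.functions import Identity

# Wrapping an int in Identity and handing it to TypedExpr with an integer
# dtype hits the int() cast in TypedExpr.__post_init__. Before this patch
# that cast raised the TypeError reported in the issue; after it, the
# Identity is unwrapped to its inner value.
typed = TypedExpr(Identity(1), torch.int32)
```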

Fix:
Define an `__int__` method for `torch.utils._sympy.functions.Identity`; likewise an `__float__` method for `float` casts.
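
A quick sanity check of the resulting behavior, mirroring the new tests below:

```python
import sympy
from torch.utils._sympy.functions import Identity

assert int(Identity(1)) == 1        # __int__ delegates to the wrapped value
assert float(Identity(1.1)) == 1.1  # __float__ does the same

try:
    int(Identity(sympy.Symbol("x")))  # symbolic inner values still refuse to cast
except TypeError:
    pass
```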

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155873
Approved by: https://github.com/williamwen42
Austin Wahle 2025-06-15 04:24:37 +00:00 committed by PyTorch MergeBot
parent 6ebe9a4f47
commit 517d2995e0
3 changed files with 77 additions and 0 deletions


@@ -2124,6 +2124,49 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
        out_compiled = torch.compile(interpolate_chunked)(x)
        self.assertEqual(out_eager, out_compiled)

    def test_max_autotune_nograd(self):
        """
        https://github.com/pytorch/pytorch/issues/155688
        Smallest repro for max-autotune not working with no_grad
        Before adding __int__ function to torch.utils._sympy.functions.Identity,
        running the max_autotune mode would raise an error:
        TypeError: Expected a number but got Identity
        """

        class ToyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_layers = nn.ModuleList(
                    [
                        nn.Linear(4, 1, bias=True),
                        nn.Linear(5, 1, bias=True),
                        nn.Linear(6, 1, bias=True),
                        nn.Linear(7, 1, bias=True),
                        nn.Linear(8, 1, bias=True),
                    ]
                )

            def forward(self, x):
                for layer in self.linear_layers:
                    x2 = layer(x)
                    x2 = F.relu(x2)
                    x = torch.cat((x, x2), dim=1)
                return x

        model = ToyModel().to("cuda")
        input_tensor = torch.randn((2, 4)).to("cuda")
        compile_default = torch.compile(model, mode="default")
        compile_max_autotune = torch.compile(model, mode="max-autotune")

        with torch.no_grad():
            default_output = compile_default(input_tensor)
            max_autotune_output = compile_max_autotune(input_tensor)
            self.assertEqual(default_output, max_autotune_output)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests


@@ -37,6 +37,7 @@ from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRanges
from torch._inductor.bounds import ValueRangeAnalysis
from torch._inductor.index_propagation import TypedExpr

UNARY_OPS = [

@@ -968,6 +969,33 @@ class TestIdentity(TestCase):
        self.assertEqual(expanded.count(Identity), 0)
        self.assertEqual(expanded, arg)

    def test_cast_identity_int(self):
        num = 1
        expr = Identity(num)
        self.assertEqual(num, int(expr))

    def test_cast_identity_float(self):
        num = 1.1
        expr = Identity(num)
        self.assertEqual(num, float(expr))

    def test_cast_identity_illegal(self):
        sym = Identity(sympy.Symbol("x"))
        self.assertRaises(TypeError, int, sym)
        self.assertRaises(TypeError, float, sym)

        tup = (0, 1, 2)
        tup_I = Identity(tup)
        self.assertRaises(TypeError, int, tup_I)
        self.assertRaises(TypeError, float, tup_I)


class TestTypedExpr(TestCase):
    def test_typed_expr(self):
        I = Identity(1)
        typed_I = TypedExpr(I, torch.int32)
        self.assertEqual(typed_I.expr, 1)


instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)
instantiate_parametrized_tests(TestSympySolve)


@@ -1300,6 +1300,12 @@ class Identity(sympy.Function):
        # Removes the identity op.
        return self.args[0]

    def __int__(self) -> int:
        return int(self.args[0])

    def __float__(self) -> float:
        return float(self.args[0])


def make_opaque_unary_fn(name):
    class OpaqueUnaryFn(sympy.Function):