mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add__int__ and __float__ methods to _sympy.functions.Identity (#155873)
Fixes #155688
Root Cause:
in [`torch/_inductor/index_propagation.py`](f151b20123/torch/_inductor/index_propagation.py (L57-L68))
When creating a `TypedExpr` from an `Identity` (a `torch.utils._sympy.functions.Identity`, not a `sympy.matrices.expressions.Identity `) and the inner value of the identity, `Identity.args[0]`, is any torch int type, the `TypedExpr.__post_init__` method tries to cast the Identity object to a python `int`. This is where to `TypeError` from the issue was raised, because Identity does not know how to cast to an `int`.
Fix:
Define `__int__` method for `torch.utils._sympy.functions.Identity`.
wlog for `float`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155873
Approved by: https://github.com/williamwen42
This commit is contained in:
parent
6ebe9a4f47
commit
517d2995e0
|
|
@ -2124,6 +2124,49 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
|
|||
out_compiled = torch.compile(interpolate_chunked)(x)
|
||||
self.assertEqual(out_eager, out_compiled)
|
||||
|
||||
def test_max_autotune_nograd(self):
|
||||
"""
|
||||
https://github.com/pytorch/pytorch/issues/155688
|
||||
Smallest repro for max-autotune not working with no_grad
|
||||
Before adding __int__ function to torch.utils._sympy.functions.Identity,
|
||||
running the max_autotune mode would raise an error:
|
||||
TypeError: Expected a number but got Identity
|
||||
"""
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.linear_layers = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(4, 1, bias=True),
|
||||
nn.Linear(5, 1, bias=True),
|
||||
nn.Linear(6, 1, bias=True),
|
||||
nn.Linear(7, 1, bias=True),
|
||||
nn.Linear(8, 1, bias=True),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.linear_layers:
|
||||
x2 = layer(x)
|
||||
x2 = F.relu(x2)
|
||||
x = torch.cat((x, x2), dim=1)
|
||||
|
||||
return x
|
||||
|
||||
model = ToyModel().to("cuda")
|
||||
input_tensor = torch.randn((2, 4)).to("cuda")
|
||||
|
||||
compile_default = torch.compile(model, mode="default")
|
||||
compile_max_autotune = torch.compile(model, mode="max-autotune")
|
||||
|
||||
with torch.no_grad():
|
||||
default_output = compile_default(input_tensor)
|
||||
max_autotune_output = compile_max_autotune(input_tensor)
|
||||
|
||||
self.assertEqual(default_output, max_autotune_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from torch.utils._sympy.singleton_int import SingletonInt
|
|||
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
from torch._inductor.bounds import ValueRangeAnalysis
|
||||
from torch._inductor.index_propagation import TypedExpr
|
||||
|
||||
|
||||
UNARY_OPS = [
|
||||
|
|
@ -968,6 +969,33 @@ class TestIdentity(TestCase):
|
|||
self.assertEqual(expanded.count(Identity), 0)
|
||||
self.assertEqual(expanded, arg)
|
||||
|
||||
def test_cast_identity_int(self):
|
||||
num = 1
|
||||
expr = Identity(num)
|
||||
self.assertEqual(num, int(expr))
|
||||
|
||||
def test_cast_identity_float(self):
|
||||
num = 1.1
|
||||
expr = Identity(num)
|
||||
self.assertEqual(num, float(expr))
|
||||
|
||||
def test_cast_identity_illegal(self):
|
||||
sym = Identity(sympy.Symbol("x"))
|
||||
self.assertRaises(TypeError, int, sym)
|
||||
self.assertRaises(TypeError, float, sym)
|
||||
|
||||
tup = (0, 1, 2)
|
||||
tup_I = Identity(tup)
|
||||
self.assertRaises(TypeError, int, tup_I)
|
||||
self.assertRaises(TypeError, float, tup_I)
|
||||
|
||||
class TestTypedExpr(TestCase):
|
||||
def test_typed_expr(self):
|
||||
I = Identity(1)
|
||||
typed_I = TypedExpr(I, torch.int32)
|
||||
self.assertEqual(typed_I.expr, 1)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestValueRanges)
|
||||
instantiate_parametrized_tests(TestSympyInterp)
|
||||
instantiate_parametrized_tests(TestSympySolve)
|
||||
|
|
|
|||
|
|
@ -1300,6 +1300,12 @@ class Identity(sympy.Function):
|
|||
# Removes the identity op.
|
||||
return self.args[0]
|
||||
|
||||
def __int__(self) -> int:
|
||||
return int(self.args[0])
|
||||
|
||||
def __float__(self) -> float:
|
||||
return float(self.args[0])
|
||||
|
||||
|
||||
def make_opaque_unary_fn(name):
|
||||
class OpaqueUnaryFn(sympy.Function):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user