diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 1dc42c4405d..338da69b201 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -20,7 +20,7 @@ import weakref from abc import ABC from collections import namedtuple from copy import deepcopy -from enum import Enum +from enum import Enum, IntEnum from functools import wraps from typing import Any, Dict, Iterator, List, Tuple from unittest import mock @@ -457,7 +457,7 @@ class PartialT5(torch.nn.Module): if past_key_value is not None: assert ( len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" real_seq_length += ( past_key_value[0].shape[2] if query_length is None else query_length ) @@ -4546,6 +4546,84 @@ class ReproTests(torch._dynamo.test_case.TestCase): f(*args) self.assertEqual(num_compiles, 1) + def test_issue134451(self): + class BoundingBox2DIndex(IntEnum): + _X = 0 + _Y = 1 + _HEADING = 2 + _LENGTH = 3 + _WIDTH = 4 + + @classmethod + def size(cls): + return 5 + + @classmethod + @property + def X(cls): + return cls._X + + @classmethod + @property + def Y(cls): + return cls._Y + + @classmethod + @property + def HEADING(cls): + return cls._HEADING + + @classmethod + @property + def LENGTH(cls): + return cls._LENGTH + + @classmethod + @property + def WIDTH(cls): + return cls._WIDTH + + @classmethod + @property + def POINT(cls): + # assumes X, Y have subsequent indices + return slice(cls._X, cls._Y + 1) + + @classmethod + @property + def STATE_SE2(cls): + # assumes X, Y, HEADING have subsequent indices + return slice(cls._X, cls._HEADING + 1) + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self._mlp_states = nn.Sequential( + nn.Linear(10, 20), + nn.ReLU(), + nn.Linear(20, BoundingBox2DIndex.size()), + ) + + def forward(self, x): + agent_states = self._mlp_states(x) + agent_states[..., BoundingBox2DIndex.POINT] = ( + agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32 + ) + agent_states[..., BoundingBox2DIndex.HEADING] = ( + agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi + ) + return agent_states + + model = SimpleModel().eval() + input_tensor = torch.randn(1, 10, dtype=torch.float32) + opt = torch.compile(model.eval(), backend="eager", fullgraph=True) + actual = opt(input_tensor) + try: + expected = model(input_tensor) + except Exception as e: + raise unittest.SkipTest("eager failed, requires Python>=3.12") from e + self.assertEqual(actual, expected) + def test_invalid_seq_unpack(self): def myfn(arg): (a, b) = arg diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 65de0aab6ef..de357cf8094 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -222,6 +222,8 @@ class EnumVariable(VariableTracker): unimplemented("Enum variable is constructed with non constant values") def as_proxy(self): + if isinstance(self.value, int): + return int(self.value) # convert IntEnum to a normal int return self.value def __str__(self) -> str: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 14499b4d2e4..32057c838d7 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -193,6 +193,10 @@ class UserDefinedClassVariable(UserDefinedVariable): else: return SourcelessBuilder.create(tx, func) elif isinstance(obj, classmethod): + if isinstance(obj.__func__, property): + return variables.UserFunctionVariable(obj.__func__.fget).call_function( + tx, [self], {} + ) return variables.UserMethodVariable(obj.__func__, self, source=source) elif isinstance(obj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static(dict, "fromkeys")