[dynamo] Fix support for classmethod(property(...)) (#134968)

Fixes #134451

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134968
Approved by: https://github.com/yanboliang
This commit is contained in:
Jason Ansel 2024-09-16 21:14:36 -07:00 committed by PyTorch MergeBot
parent 9aa22eabe7
commit a0207c8471
3 changed files with 86 additions and 2 deletions

View File

@ -20,7 +20,7 @@ import weakref
from abc import ABC
from collections import namedtuple
from copy import deepcopy
from enum import Enum
from enum import Enum, IntEnum
from functools import wraps
from typing import Any, Dict, Iterator, List, Tuple
from unittest import mock
@ -457,7 +457,7 @@ class PartialT5(torch.nn.Module):
if past_key_value is not None:
assert (
len(past_key_value) == 2
), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
real_seq_length += (
past_key_value[0].shape[2] if query_length is None else query_length
)
@ -4546,6 +4546,84 @@ class ReproTests(torch._dynamo.test_case.TestCase):
f(*args)
self.assertEqual(num_compiles, 1)
def test_issue134451(self):
class BoundingBox2DIndex(IntEnum):
_X = 0
_Y = 1
_HEADING = 2
_LENGTH = 3
_WIDTH = 4
@classmethod
def size(cls):
return 5
@classmethod
@property
def X(cls):
return cls._X
@classmethod
@property
def Y(cls):
return cls._Y
@classmethod
@property
def HEADING(cls):
return cls._HEADING
@classmethod
@property
def LENGTH(cls):
return cls._LENGTH
@classmethod
@property
def WIDTH(cls):
return cls._WIDTH
@classmethod
@property
def POINT(cls):
# assumes X, Y have subsequent indices
return slice(cls._X, cls._Y + 1)
@classmethod
@property
def STATE_SE2(cls):
# assumes X, Y, HEADING have subsequent indices
return slice(cls._X, cls._HEADING + 1)
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self._mlp_states = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, BoundingBox2DIndex.size()),
)
def forward(self, x):
agent_states = self._mlp_states(x)
agent_states[..., BoundingBox2DIndex.POINT] = (
agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32
)
agent_states[..., BoundingBox2DIndex.HEADING] = (
agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi
)
return agent_states
model = SimpleModel().eval()
input_tensor = torch.randn(1, 10, dtype=torch.float32)
opt = torch.compile(model.eval(), backend="eager", fullgraph=True)
actual = opt(input_tensor)
try:
expected = model(input_tensor)
except Exception as e:
raise unittest.SkipTest("eager failed, requires Python>=3.12") from e
self.assertEqual(actual, expected)
def test_invalid_seq_unpack(self):
def myfn(arg):
(a, b) = arg

View File

@ -222,6 +222,8 @@ class EnumVariable(VariableTracker):
unimplemented("Enum variable is constructed with non constant values")
def as_proxy(self):
if isinstance(self.value, int):
return int(self.value) # convert IntEnum to a normal int
return self.value
def __str__(self) -> str:

View File

@ -193,6 +193,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
else:
return SourcelessBuilder.create(tx, func)
elif isinstance(obj, classmethod):
if isinstance(obj.__func__, property):
return variables.UserFunctionVariable(obj.__func__.fget).call_function(
tx, [self], {}
)
return variables.UserMethodVariable(obj.__func__, self, source=source)
elif isinstance(obj, types.ClassMethodDescriptorType):
# e.g.: inspect.getattr_static(dict, "fromkeys")