mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] Fix support for classmethod(property(...)) (#134968)
Fixes #134451 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134968 Approved by: https://github.com/yanboliang
This commit is contained in:
parent
9aa22eabe7
commit
a0207c8471
|
|
@ -20,7 +20,7 @@ import weakref
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum, IntEnum
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Iterator, List, Tuple
|
from typing import Any, Dict, Iterator, List, Tuple
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
@ -4546,6 +4546,84 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||||
f(*args)
|
f(*args)
|
||||||
self.assertEqual(num_compiles, 1)
|
self.assertEqual(num_compiles, 1)
|
||||||
|
|
||||||
|
def test_issue134451(self):
|
||||||
|
class BoundingBox2DIndex(IntEnum):
|
||||||
|
_X = 0
|
||||||
|
_Y = 1
|
||||||
|
_HEADING = 2
|
||||||
|
_LENGTH = 3
|
||||||
|
_WIDTH = 4
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def size(cls):
|
||||||
|
return 5
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@property
|
||||||
|
def X(cls):
|
||||||
|
return cls._X
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@property
|
||||||
|
def Y(cls):
|
||||||
|
return cls._Y
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@property
|
||||||
|
def HEADING(cls):
|
||||||
|
return cls._HEADING
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@property
|
||||||
|
def LENGTH(cls):
|
||||||
|
return cls._LENGTH
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@property
|
||||||
|
def WIDTH(cls):
|
||||||
|
return cls._WIDTH
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@property
|
||||||
|
def POINT(cls):
|
||||||
|
# assumes X, Y have subsequent indices
|
||||||
|
return slice(cls._X, cls._Y + 1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@property
|
||||||
|
def STATE_SE2(cls):
|
||||||
|
# assumes X, Y, HEADING have subsequent indices
|
||||||
|
return slice(cls._X, cls._HEADING + 1)
|
||||||
|
|
||||||
|
class SimpleModel(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._mlp_states = nn.Sequential(
|
||||||
|
nn.Linear(10, 20),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(20, BoundingBox2DIndex.size()),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
agent_states = self._mlp_states(x)
|
||||||
|
agent_states[..., BoundingBox2DIndex.POINT] = (
|
||||||
|
agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32
|
||||||
|
)
|
||||||
|
agent_states[..., BoundingBox2DIndex.HEADING] = (
|
||||||
|
agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi
|
||||||
|
)
|
||||||
|
return agent_states
|
||||||
|
|
||||||
|
model = SimpleModel().eval()
|
||||||
|
input_tensor = torch.randn(1, 10, dtype=torch.float32)
|
||||||
|
opt = torch.compile(model.eval(), backend="eager", fullgraph=True)
|
||||||
|
actual = opt(input_tensor)
|
||||||
|
try:
|
||||||
|
expected = model(input_tensor)
|
||||||
|
except Exception as e:
|
||||||
|
raise unittest.SkipTest("eager failed, requires Python>=3.12") from e
|
||||||
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
def test_invalid_seq_unpack(self):
|
def test_invalid_seq_unpack(self):
|
||||||
def myfn(arg):
|
def myfn(arg):
|
||||||
(a, b) = arg
|
(a, b) = arg
|
||||||
|
|
|
||||||
|
|
@ -222,6 +222,8 @@ class EnumVariable(VariableTracker):
|
||||||
unimplemented("Enum variable is constructed with non constant values")
|
unimplemented("Enum variable is constructed with non constant values")
|
||||||
|
|
||||||
def as_proxy(self):
|
def as_proxy(self):
|
||||||
|
if isinstance(self.value, int):
|
||||||
|
return int(self.value) # convert IntEnum to a normal int
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|
|
||||||
|
|
@ -193,6 +193,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||||
else:
|
else:
|
||||||
return SourcelessBuilder.create(tx, func)
|
return SourcelessBuilder.create(tx, func)
|
||||||
elif isinstance(obj, classmethod):
|
elif isinstance(obj, classmethod):
|
||||||
|
if isinstance(obj.__func__, property):
|
||||||
|
return variables.UserFunctionVariable(obj.__func__.fget).call_function(
|
||||||
|
tx, [self], {}
|
||||||
|
)
|
||||||
return variables.UserMethodVariable(obj.__func__, self, source=source)
|
return variables.UserMethodVariable(obj.__func__, self, source=source)
|
||||||
elif isinstance(obj, types.ClassMethodDescriptorType):
|
elif isinstance(obj, types.ClassMethodDescriptorType):
|
||||||
# e.g.: inspect.getattr_static(dict, "fromkeys")
|
# e.g.: inspect.getattr_static(dict, "fromkeys")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user