mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Fix support for classmethod(property(...)) (#134968)
Fixes #134451 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134968 Approved by: https://github.com/yanboliang
This commit is contained in:
parent
9aa22eabe7
commit
a0207c8471
|
|
@ -20,7 +20,7 @@ import weakref
|
|||
from abc import ABC
|
||||
from collections import namedtuple
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from enum import Enum, IntEnum
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Iterator, List, Tuple
|
||||
from unittest import mock
|
||||
|
|
@ -457,7 +457,7 @@ class PartialT5(torch.nn.Module):
|
|||
if past_key_value is not None:
|
||||
assert (
|
||||
len(past_key_value) == 2
|
||||
), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
|
||||
), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
|
||||
real_seq_length += (
|
||||
past_key_value[0].shape[2] if query_length is None else query_length
|
||||
)
|
||||
|
|
@ -4546,6 +4546,84 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
f(*args)
|
||||
self.assertEqual(num_compiles, 1)
|
||||
|
||||
def test_issue134451(self):
|
||||
class BoundingBox2DIndex(IntEnum):
|
||||
_X = 0
|
||||
_Y = 1
|
||||
_HEADING = 2
|
||||
_LENGTH = 3
|
||||
_WIDTH = 4
|
||||
|
||||
@classmethod
|
||||
def size(cls):
|
||||
return 5
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def X(cls):
|
||||
return cls._X
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def Y(cls):
|
||||
return cls._Y
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def HEADING(cls):
|
||||
return cls._HEADING
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def LENGTH(cls):
|
||||
return cls._LENGTH
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def WIDTH(cls):
|
||||
return cls._WIDTH
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def POINT(cls):
|
||||
# assumes X, Y have subsequent indices
|
||||
return slice(cls._X, cls._Y + 1)
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def STATE_SE2(cls):
|
||||
# assumes X, Y, HEADING have subsequent indices
|
||||
return slice(cls._X, cls._HEADING + 1)
|
||||
|
||||
class SimpleModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._mlp_states = nn.Sequential(
|
||||
nn.Linear(10, 20),
|
||||
nn.ReLU(),
|
||||
nn.Linear(20, BoundingBox2DIndex.size()),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
agent_states = self._mlp_states(x)
|
||||
agent_states[..., BoundingBox2DIndex.POINT] = (
|
||||
agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32
|
||||
)
|
||||
agent_states[..., BoundingBox2DIndex.HEADING] = (
|
||||
agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi
|
||||
)
|
||||
return agent_states
|
||||
|
||||
model = SimpleModel().eval()
|
||||
input_tensor = torch.randn(1, 10, dtype=torch.float32)
|
||||
opt = torch.compile(model.eval(), backend="eager", fullgraph=True)
|
||||
actual = opt(input_tensor)
|
||||
try:
|
||||
expected = model(input_tensor)
|
||||
except Exception as e:
|
||||
raise unittest.SkipTest("eager failed, requires Python>=3.12") from e
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_invalid_seq_unpack(self):
|
||||
def myfn(arg):
|
||||
(a, b) = arg
|
||||
|
|
|
|||
|
|
@ -222,6 +222,8 @@ class EnumVariable(VariableTracker):
|
|||
unimplemented("Enum variable is constructed with non constant values")
|
||||
|
||||
def as_proxy(self):
|
||||
if isinstance(self.value, int):
|
||||
return int(self.value) # convert IntEnum to a normal int
|
||||
return self.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
|
|
|||
|
|
@ -193,6 +193,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
else:
|
||||
return SourcelessBuilder.create(tx, func)
|
||||
elif isinstance(obj, classmethod):
|
||||
if isinstance(obj.__func__, property):
|
||||
return variables.UserFunctionVariable(obj.__func__.fget).call_function(
|
||||
tx, [self], {}
|
||||
)
|
||||
return variables.UserMethodVariable(obj.__func__, self, source=source)
|
||||
elif isinstance(obj, types.ClassMethodDescriptorType):
|
||||
# e.g.: inspect.getattr_static(dict, "fromkeys")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user