mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Support namedtuple subclass (#153982)
Fixes #133762. This involves 1. support tuple subclass constructed inside compile region. 2. handle the "fake" global scope associated with NamedTuple-generated `__new__`. 3. handle `namedtuple._tuplegetter` more faithfully. Differential Revision: [D75488091](https://our.internmc.facebook.com/intern/diff/D75488091) Pull Request resolved: https://github.com/pytorch/pytorch/pull/153982 Approved by: https://github.com/jansel ghstack dependencies: #154176
This commit is contained in:
parent
8002d22ce3
commit
7183f52675
|
|
@ -4842,6 +4842,68 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertTrue(ref_tup.checked)
|
||||
self.assertTrue(res_tup.checked)
|
||||
|
||||
def test_udf_tuple_construction(self):
|
||||
class MyTuple(tuple): # noqa: SLOT001
|
||||
pass
|
||||
|
||||
def fn(x):
|
||||
tup = MyTuple([1, 2, 3])
|
||||
if 3 in tup:
|
||||
x = torch.cos(x)
|
||||
else:
|
||||
x = torch.sin(x)
|
||||
return x, tup
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
x = torch.randn(4)
|
||||
ref_x, ref_tup = fn(x)
|
||||
res_x, res_tup = opt_fn(x)
|
||||
self.assertEqual(ref_x, res_x)
|
||||
self.assertEqual(ref_tup, res_tup)
|
||||
|
||||
def test_udf_tuple_construction_custom_new(self):
|
||||
class MyTuple(tuple): # noqa: SLOT001
|
||||
def __new__(cls, *args, **kwargs):
|
||||
return super().__new__(cls, [1, 2, 3])
|
||||
|
||||
def fn(x):
|
||||
tup = MyTuple()
|
||||
if 3 in tup:
|
||||
x = torch.cos(x)
|
||||
else:
|
||||
x = torch.sin(x)
|
||||
return x, tup
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
x = torch.randn(4)
|
||||
ref_x, ref_tup = fn(x)
|
||||
res_x, res_tup = opt_fn(x)
|
||||
self.assertEqual(ref_x, res_x)
|
||||
self.assertEqual(ref_tup, res_tup)
|
||||
|
||||
def test_udf_namedtuple(self):
|
||||
class MyTuple(NamedTuple):
|
||||
a: torch.Tensor
|
||||
b: torch.Tensor
|
||||
|
||||
class PairTensor(MyTuple):
|
||||
def __new__(cls, a, b):
|
||||
return super().__new__(cls, a, b)
|
||||
|
||||
def __add__(self, other):
|
||||
return PairTensor(self.a + other.a, self.b + other.b)
|
||||
|
||||
def fn(pair1, pair2):
|
||||
return pair1 + pair2
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
pair1 = PairTensor(torch.randn(4), torch.randn(2, 8))
|
||||
pair2 = PairTensor(torch.randn(1), torch.randn(2, 1))
|
||||
ref = fn(pair1, pair2)
|
||||
res = opt_fn(pair1, pair2)
|
||||
self.assertEqual(ref.a, res.a)
|
||||
self.assertEqual(ref.b, res.b)
|
||||
|
||||
def test_udf_tuple_reconstruction(self):
|
||||
class MyTuple(tuple): # noqa: SLOT001
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -11706,6 +11706,30 @@ fn
|
|||
res = fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_descriptor_side_effect(self):
|
||||
# This pattern (readonly descriptor but writable value in `__dict__`) is
|
||||
# from scipy `_make_tuple_bunch`:
|
||||
# https://github.com/scipy/scipy/blob/maintenance/1.9.x/scipy/_lib/_bunch.py#L32-L226
|
||||
def fget(obj):
|
||||
return obj.__dict__["field"]
|
||||
|
||||
class MyClass:
|
||||
def __init__(self, n):
|
||||
self.__dict__["field"] = n
|
||||
|
||||
field = property(fget)
|
||||
|
||||
def fn(x):
|
||||
obj = MyClass(42)
|
||||
return x + obj.field, obj
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
x = torch.randn(4)
|
||||
ref_t, ref_obj = fn(x)
|
||||
res_t, res_obj = opt_fn(x)
|
||||
self.assertEqual(ref_t, res_t)
|
||||
self.assertEqual(ref_obj.field, res_obj.field)
|
||||
|
||||
def test_assert_size_stride(self):
|
||||
x = torch.randn(2, 3, 4)
|
||||
with self.assertRaisesRegex(
|
||||
|
|
|
|||
|
|
@ -256,6 +256,7 @@ class SideEffects:
|
|||
int.__getattribute__,
|
||||
str.__getattribute__,
|
||||
list.__getattribute__,
|
||||
tuple.__getattribute__,
|
||||
BaseException.__getattribute__,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4059,7 +4059,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
raise ReturnValueOp
|
||||
|
||||
def get_globals_source_and_value(self, name):
|
||||
if "__name__" in self.f_globals:
|
||||
# NamedTuple's `__new__` has a fake global scope that's not an actual
|
||||
# module. TODO generalize the check for other non-importable cases.
|
||||
# https://github.com/python/cpython/blob/8421b03b16a4852a527256cb7cdce2ab2d318548/Lib/collections/__init__.py#L441-L447
|
||||
if "__name__" in self.f_globals and not self.f_globals["__name__"].startswith(
|
||||
"namedtuple_"
|
||||
):
|
||||
module_name = self.f_globals["__name__"]
|
||||
module_source = self.import_source(module_name)
|
||||
if "torch_package" in module_name:
|
||||
|
|
|
|||
|
|
@ -1380,7 +1380,7 @@ class VariableBuilder:
|
|||
|
||||
result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source)
|
||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||
elif isinstance(value, tuple) and type(value).__new__ is tuple.__new__:
|
||||
elif isinstance(value, tuple):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
|
||||
|
||||
|
|
@ -1397,7 +1397,7 @@ class VariableBuilder:
|
|||
tuple_vt = TupleVariable(
|
||||
output, source=self.source, mutation_type=ValueMutationExisting()
|
||||
)
|
||||
result = UserDefinedTupleVariable.create(
|
||||
result = UserDefinedTupleVariable(
|
||||
value, tuple_vt=tuple_vt, source=self.source
|
||||
)
|
||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||
|
|
|
|||
|
|
@ -1205,20 +1205,17 @@ class BuiltinVariable(VariableTracker):
|
|||
and args[1].has_unpack_var_sequence(tx)
|
||||
and not kwargs
|
||||
):
|
||||
init_args = args[1].unpack_var_sequence(tx)
|
||||
tuple_vt = variables.TupleVariable(
|
||||
init_args, mutation_type=ValueMutationNew()
|
||||
)
|
||||
if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple:
|
||||
return tuple_vt
|
||||
init_args = args[1].unpack_var_sequence(tx)
|
||||
return variables.TupleVariable(
|
||||
init_args, mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
||||
result = tx.output.side_effects.track_new_user_defined_object(
|
||||
return tx.output.side_effects.track_new_user_defined_object(
|
||||
self,
|
||||
args[0],
|
||||
args[1:],
|
||||
)
|
||||
result.set_underlying_tuple_vt(tuple_vt)
|
||||
return result
|
||||
|
||||
if self.fn is list:
|
||||
list_vt = ListVariable([], mutation_type=ValueMutationNew())
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ These classes help Dynamo track and handle arbitrary Python objects during traci
|
|||
maintaining proper semantics while enabling optimizations where possible.
|
||||
"""
|
||||
|
||||
import _collections
|
||||
import builtins
|
||||
import collections
|
||||
import contextlib
|
||||
|
|
@ -1046,7 +1047,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
|
||||
def _getattr_static(self, name):
|
||||
subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ)
|
||||
import _collections
|
||||
|
||||
# In some cases, we have to do dynamic lookup because getattr_static is not enough. For example, threading.local
|
||||
# has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup.
|
||||
|
|
@ -1054,7 +1054,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
# as Dynamo tracing is concerned.
|
||||
if not object_has_getattribute(self.value) and (
|
||||
subobj is NO_SUCH_SUBOBJ # e.g., threading.local
|
||||
or isinstance(subobj, _collections._tuplegetter) # namedtuple fields
|
||||
or inspect.ismemberdescriptor(subobj) # e.g., __slots__
|
||||
or inspect.isgetsetdescriptor(subobj) # e.g., __dict__
|
||||
or self._is_c_defined_property(subobj)
|
||||
|
|
@ -1213,6 +1212,14 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
return variables.UserMethodVariable(
|
||||
subobj.fget, self, source=source
|
||||
).call_function(tx, [], {})
|
||||
elif isinstance(subobj, _collections._tuplegetter):
|
||||
# namedtuple fields are represented by _tuplegetter, and here we
|
||||
# emulate its `__get__`, which is implemented in C.
|
||||
_, (idx, _) = subobj.__reduce__()
|
||||
# Don't go through the `__getitem__` method anymore, see
|
||||
# https://github.com/python/cpython/blob/470941782f74288823b445120f6383914b659f23/Modules/_collectionsmodule.c#L2690
|
||||
assert isinstance(self, UserDefinedTupleVariable)
|
||||
return self._tuple_vt.items[idx]
|
||||
elif isinstance(subobj, staticmethod):
|
||||
# Safe because `staticmethod.__get__` basically won't trigger user
|
||||
# code and just returns the underlying `__func__`:
|
||||
|
|
@ -1696,18 +1703,24 @@ class UserDefinedTupleVariable(UserDefinedObjectVariable):
|
|||
|
||||
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields
|
||||
|
||||
def __init__(self, value, **kwargs):
|
||||
super().__init__(value, **kwargs)
|
||||
self._tuple_vt = None
|
||||
|
||||
def set_underlying_tuple_vt(self, tuple_vt):
|
||||
def __init__(self, value, tuple_vt=None, init_args=None, **kwargs):
|
||||
super().__init__(value, init_args=init_args, **kwargs)
|
||||
self._tuple_vt = tuple_vt
|
||||
if self._tuple_vt is None:
|
||||
assert self.source is None, (
|
||||
"tuple_vt must be constructed by builder.py when source is present"
|
||||
)
|
||||
# Emulate `tuple.__new__`
|
||||
# https://github.com/python/cpython/blob/3.11/Objects/tupleobject.c#L697-L710
|
||||
#
|
||||
# TODO this duplicates the logic in `BuiltinVariable(tuple)`
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
@staticmethod
|
||||
def create(value, tuple_vt, **kwargs):
|
||||
result = UserDefinedTupleVariable(value, **kwargs)
|
||||
result.set_underlying_tuple_vt(tuple_vt)
|
||||
return result
|
||||
tx = InstructionTranslator.current_tx()
|
||||
elems = init_args[0].unpack_var_sequence(tx)
|
||||
self._tuple_vt = variables.TupleVariable(
|
||||
elems, mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user