[dynamo] Support namedtuple subclass (#153982)

Fixes #133762. This involves:
1. Supporting tuple subclasses that are constructed inside a compiled region.
2. Handling the "fake" global scope associated with the NamedTuple-generated
   `__new__`.
3. Handling `_collections._tuplegetter` (the descriptor behind namedtuple
   fields) more faithfully.

Differential Revision: [D75488091](https://our.internmc.facebook.com/intern/diff/D75488091)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153982
Approved by: https://github.com/jansel
ghstack dependencies: #154176
This commit is contained in:
Ryan Guo 2025-05-27 16:13:15 -07:00 committed by PyTorch MergeBot
parent 8002d22ce3
commit 7183f52675
7 changed files with 125 additions and 23 deletions

View File

@ -4842,6 +4842,68 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
self.assertTrue(ref_tup.checked)
self.assertTrue(res_tup.checked)
def test_udf_tuple_construction(self):
# Checks that a plain `tuple` subclass (one that does not override
# `__new__`) can be constructed inside a function compiled with
# `fullgraph=True`, and that the compiled outputs match eager execution.
class MyTuple(tuple): # noqa: SLOT001
pass
def fn(x):
# The subclass instance is created inside the compiled function itself,
# not passed in as an input.
tup = MyTuple([1, 2, 3])
# Branch on a membership test against the freshly constructed tuple.
if 3 in tup:
x = torch.cos(x)
else:
x = torch.sin(x)
return x, tup
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
# Run the eager and compiled versions on the same input and compare both
# the tensor result and the returned tuple-subclass instance.
ref_x, ref_tup = fn(x)
res_x, res_tup = opt_fn(x)
self.assertEqual(ref_x, res_x)
self.assertEqual(ref_tup, res_tup)
def test_udf_tuple_construction_custom_new(self):
# Same scenario as constructing a tuple subclass inside a compiled
# function, but here the subclass overrides `__new__`.
class MyTuple(tuple): # noqa: SLOT001
def __new__(cls, *args, **kwargs):
# Ignores whatever the caller passed and always builds the tuple from
# the fixed list `[1, 2, 3]` via `tuple.__new__`.
return super().__new__(cls, [1, 2, 3])
def fn(x):
# Called with no arguments; the contents come from the custom `__new__`.
tup = MyTuple()
# Branch on a membership test against the freshly constructed tuple.
if 3 in tup:
x = torch.cos(x)
else:
x = torch.sin(x)
return x, tup
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
# Run the eager and compiled versions on the same input and compare both
# the tensor result and the returned tuple-subclass instance.
ref_x, ref_tup = fn(x)
res_x, res_tup = opt_fn(x)
self.assertEqual(ref_x, res_x)
self.assertEqual(ref_tup, res_tup)
def test_udf_namedtuple(self):
# Checks that a user-defined subclass of a `NamedTuple` class can be passed
# into, used by field name within, and newly constructed inside a function
# compiled with `fullgraph=True`.
class MyTuple(NamedTuple):
a: torch.Tensor
b: torch.Tensor
class PairTensor(MyTuple):
# Overrides `__new__` and forwards both fields to the NamedTuple-generated
# `__new__` of the parent class.
def __new__(cls, a, b):
return super().__new__(cls, a, b)
def __add__(self, other):
# Reads the fields by name and builds a new `PairTensor` from the
# field-wise sums.
return PairTensor(self.a + other.a, self.b + other.b)
def fn(pair1, pair2):
# `+` dispatches to `PairTensor.__add__` defined above.
return pair1 + pair2
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
# The two pairs deliberately use different tensor shapes for each field.
pair1 = PairTensor(torch.randn(4), torch.randn(2, 8))
pair2 = PairTensor(torch.randn(1), torch.randn(2, 1))
# Compare eager and compiled results field by field.
ref = fn(pair1, pair2)
res = opt_fn(pair1, pair2)
self.assertEqual(ref.a, res.a)
self.assertEqual(ref.b, res.b)
def test_udf_tuple_reconstruction(self):
class MyTuple(tuple): # noqa: SLOT001
pass

View File

@ -11706,6 +11706,30 @@ fn
res = fn(x)
self.assertEqual(ref, res)
def test_descriptor_side_effect(self):
# This pattern (readonly descriptor but writable value in `__dict__`) is
# from scipy `_make_tuple_bunch`:
# https://github.com/scipy/scipy/blob/maintenance/1.9.x/scipy/_lib/_bunch.py#L32-L226
#
# The test constructs such an object inside a compiled function, reads the
# attribute through the property, returns the object, and checks that the
# compiled results match eager execution.
def fget(obj):
# Getter that reads the value straight from the instance `__dict__`.
return obj.__dict__["field"]
class MyClass:
def __init__(self, n):
# Writes directly into `__dict__` under the same name as the property
# defined below, rather than using normal attribute assignment.
self.__dict__["field"] = n
# Property with only a getter; no setter is supplied.
field = property(fget)
def fn(x):
obj = MyClass(42)
# `obj.field` is resolved through the property getter `fget`.
return x + obj.field, obj
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
# Compare the tensor result and the `field` value seen on the returned
# object between eager and compiled execution.
ref_t, ref_obj = fn(x)
res_t, res_obj = opt_fn(x)
self.assertEqual(ref_t, res_t)
self.assertEqual(ref_obj.field, res_obj.field)
def test_assert_size_stride(self):
x = torch.randn(2, 3, 4)
with self.assertRaisesRegex(

View File

@ -256,6 +256,7 @@ class SideEffects:
int.__getattribute__,
str.__getattribute__,
list.__getattribute__,
tuple.__getattribute__,
BaseException.__getattribute__,
)

View File

@ -4059,7 +4059,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
raise ReturnValueOp
def get_globals_source_and_value(self, name):
if "__name__" in self.f_globals:
# NamedTuple's `__new__` has a fake global scope that's not an actual
# module. TODO generalize the check for other non-importable cases.
# https://github.com/python/cpython/blob/8421b03b16a4852a527256cb7cdce2ab2d318548/Lib/collections/__init__.py#L441-L447
if "__name__" in self.f_globals and not self.f_globals["__name__"].startswith(
"namedtuple_"
):
module_name = self.f_globals["__name__"]
module_source = self.import_source(module_name)
if "torch_package" in module_name:

View File

@ -1380,7 +1380,7 @@ class VariableBuilder:
result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source)
return self.tx.output.side_effects.track_object_existing(value, result)
elif isinstance(value, tuple) and type(value).__new__ is tuple.__new__:
elif isinstance(value, tuple):
self.install_guards(GuardBuilder.TYPE_MATCH)
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
@ -1397,7 +1397,7 @@ class VariableBuilder:
tuple_vt = TupleVariable(
output, source=self.source, mutation_type=ValueMutationExisting()
)
result = UserDefinedTupleVariable.create(
result = UserDefinedTupleVariable(
value, tuple_vt=tuple_vt, source=self.source
)
return self.tx.output.side_effects.track_object_existing(value, result)

View File

@ -1205,20 +1205,17 @@ class BuiltinVariable(VariableTracker):
and args[1].has_unpack_var_sequence(tx)
and not kwargs
):
init_args = args[1].unpack_var_sequence(tx)
tuple_vt = variables.TupleVariable(
init_args, mutation_type=ValueMutationNew()
)
if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple:
return tuple_vt
init_args = args[1].unpack_var_sequence(tx)
return variables.TupleVariable(
init_args, mutation_type=ValueMutationNew()
)
result = tx.output.side_effects.track_new_user_defined_object(
return tx.output.side_effects.track_new_user_defined_object(
self,
args[0],
args[1:],
)
result.set_underlying_tuple_vt(tuple_vt)
return result
if self.fn is list:
list_vt = ListVariable([], mutation_type=ValueMutationNew())

View File

@ -21,6 +21,7 @@ These classes help Dynamo track and handle arbitrary Python objects during traci
maintaining proper semantics while enabling optimizations where possible.
"""
import _collections
import builtins
import collections
import contextlib
@ -1046,7 +1047,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
def _getattr_static(self, name):
subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ)
import _collections
# In some cases, we have to do dynamic lookup because getattr_static is not enough. For example, threading.local
# has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup.
@ -1054,7 +1054,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
# as Dynamo tracing is concerned.
if not object_has_getattribute(self.value) and (
subobj is NO_SUCH_SUBOBJ # e.g., threading.local
or isinstance(subobj, _collections._tuplegetter) # namedtuple fields
or inspect.ismemberdescriptor(subobj) # e.g., __slots__
or inspect.isgetsetdescriptor(subobj) # e.g., __dict__
or self._is_c_defined_property(subobj)
@ -1213,6 +1212,14 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return variables.UserMethodVariable(
subobj.fget, self, source=source
).call_function(tx, [], {})
elif isinstance(subobj, _collections._tuplegetter):
# namedtuple fields are represented by _tuplegetter, and here we
# emulate its `__get__`, which is implemented in C.
_, (idx, _) = subobj.__reduce__()
# Don't go through the `__getitem__` method anymore, see
# https://github.com/python/cpython/blob/470941782f74288823b445120f6383914b659f23/Modules/_collectionsmodule.c#L2690
assert isinstance(self, UserDefinedTupleVariable)
return self._tuple_vt.items[idx]
elif isinstance(subobj, staticmethod):
# Safe because `staticmethod.__get__` basically won't trigger user
# code and just returns the underlying `__func__`:
@ -1696,18 +1703,24 @@ class UserDefinedTupleVariable(UserDefinedObjectVariable):
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields
def __init__(self, value, **kwargs):
super().__init__(value, **kwargs)
self._tuple_vt = None
def set_underlying_tuple_vt(self, tuple_vt):
def __init__(self, value, tuple_vt=None, init_args=None, **kwargs):
super().__init__(value, init_args=init_args, **kwargs)
self._tuple_vt = tuple_vt
if self._tuple_vt is None:
assert self.source is None, (
"tuple_vt must be constructed by builder.py when source is present"
)
# Emulate `tuple.__new__`
# https://github.com/python/cpython/blob/3.11/Objects/tupleobject.c#L697-L710
#
# TODO this duplicates the logic in `BuiltinVariable(tuple)`
from torch._dynamo.symbolic_convert import InstructionTranslator
@staticmethod
def create(value, tuple_vt, **kwargs):
result = UserDefinedTupleVariable(value, **kwargs)
result.set_underlying_tuple_vt(tuple_vt)
return result
tx = InstructionTranslator.current_tx()
elems = init_args[0].unpack_var_sequence(tx)
self._tuple_vt = variables.TupleVariable(
elems, mutation_type=ValueMutationNew()
)
def call_method(
self,