Fixes for collections.NamedTuple (#159367)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159367
Approved by: https://github.com/mlazos
ghstack dependencies: #159365, #159366, #159368, #159483, #159902, #159864, #159865
This commit is contained in:
Guilherme Leobas 2025-08-16 13:41:17 -03:00 committed by PyTorch MergeBot
parent 87d6831b2e
commit c6333f7dae
12 changed files with 74 additions and 16 deletions

View File

@ -562,6 +562,11 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
args = [a, b]
return sub(*args)
@make_test
def test_tuple_map(a, b):
t = tuple(map(torch.sin, [a, b]))
return t[0] + t[1]
def test_size_tuple_add(self):
def fn():
size = torch.Size([])
@ -2016,6 +2021,21 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
tmp = mytuple(a, xy=b)
return mytuple(tmp.x, tmp[1], tmp.xy + b)
@make_test
def test_namedtuple_replace(a, b):
mytuple = collections.namedtuple("mytuple", ["x", "y"])
t = mytuple(a, b)
t._replace(x=b)
return t.x + t.y
@make_test
def test_namedtuple_fields(a, b):
mytuple = collections.namedtuple("mytuple", ["x", "y"])
if mytuple._fields == ("x", "y"):
return a + b
else:
return a - b
class MyNamedTuple(NamedTuple):
first: torch.Tensor
second: torch.Tensor

View File

@ -1705,16 +1705,17 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
if hasattr(packed, "b"):
b = packed.b + 1
c = packed[2]
return a + b + c
d = len(packed._fields)
return a + b + c + d
v1 = torch.Tensor([1])
v2 = torch.Tensor([2])
v3 = torch.Tensor([3])
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
self.assertEqual(opt_fn(MyTuple(v1, v2, v3))[0], 7)
self.assertEqual(opt_fn(MyTuple(v1, v2, v3))[0], 10)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 3)
self.assertEqual(cnts.op_count, 4)
def test_namedtuple3(self):
def fn(x, packed):

View File

@ -139,6 +139,7 @@ from .source import (
GradSource,
ListGetItemSource,
LocalSource,
NamedTupleFieldsSource,
NNModuleSource,
NonSerializableSetGetItemSource,
NumpyTensorSource,
@ -727,6 +728,7 @@ def _get_closure_vars() -> dict[str, object]:
"___normalize_range_iter": normalize_range_iter,
"___tuple_iterator_getitem": tuple_iterator_getitem,
"___dataclass_fields": dataclass_fields,
"___namedtuple_fields": lambda x: x._fields,
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
"__math_isnan": math.isnan,
"__numpy_isnan": None if np is None else np.isnan,
@ -1680,6 +1682,14 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, NamedTupleFieldsSource):
assert base_guard_manager
out = base_guard_manager.lambda_manager(
python_lambda=lambda x: x._fields,
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, CodeSource):
assert base_guard_manager # to make mypy happy
out = base_guard_manager.code_manager(

View File

@ -830,6 +830,19 @@ class TupleIteratorGetItemSource(GetItemSource):
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
@dataclasses.dataclass(frozen=True)
class NamedTupleFieldsSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
codegen.extend_output(codegen.create_load_attrs("_fields"))
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self) -> str:
return f"___namedtuple_fields({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class DataclassFieldsSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen") -> None:

View File

@ -3595,6 +3595,12 @@ class SourcelessBuilder:
if trace_rules.is_callable_allowed(value):
tx.output.has_user_defined_allowed_in_graph = True
return trace_rules.lookup_callable(value)(value)
elif callable(value) and UserDefinedClassVariable.is_supported_new_method(
value
):
# NamedTuple._make uses an alias of tuple.__new__
obj = trace_rules.lookup_callable(value.__self__)(value.__self__)
return GetAttrVariable(obj, "__new__")
elif is_function_or_wrapper(value):
return trace_rules.lookup(value)(value)
elif isinstance(

View File

@ -1373,11 +1373,11 @@ class BuiltinVariable(VariableTracker):
if (
self.fn is tuple
and len(args) == 2
and args[1].has_unpack_var_sequence(tx)
and args[1].has_force_unpack_var_sequence(tx)
and not kwargs
):
if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple:
init_args = args[1].unpack_var_sequence(tx)
init_args = args[1].force_unpack_var_sequence(tx)
return variables.TupleVariable(
init_args, mutation_type=ValueMutationNew()
)
@ -2001,10 +2001,7 @@ class BuiltinVariable(VariableTracker):
if kwargs:
assert len(kwargs) == 1 and "strict" in kwargs
strict = kwargs.pop("strict", False)
args = [
arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg
for arg in args
]
args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args]
return variables.ZipVariable(
args, strict=strict, mutation_type=ValueMutationNew()
)

View File

@ -1850,10 +1850,17 @@ class CollectionsNamedTupleFunction(UserFunctionVariable):
) -> "VariableTracker":
constant_args = check_constant_args(args, kwargs)
if constant_args:
value = self.fn(
*[x.as_python_constant() for x in args],
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
try:
value = self.fn(
*[x.as_python_constant() for x in args],
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
except TypeError as exc:
raise_observed_exception(
type(exc),
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
return variables.UserDefinedClassVariable(
value, mutation_type=ValueMutationNew()
)

View File

@ -345,7 +345,7 @@ class ZipVariable(IteratorVariable):
def __init__(
self,
iterables: list[Union[list[VariableTracker], VariableTracker]],
iterables: list[VariableTracker],
strict: bool = False,
**kwargs,
) -> None:

View File

@ -27,7 +27,7 @@ import torch.fx
from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import raise_observed_exception, unimplemented_v2
from ..source import AttrSource
from ..source import AttrSource, NamedTupleFieldsSource
from ..utils import (
cmp_name_to_op_mapping,
cmp_name_to_op_str_mapping,
@ -1150,6 +1150,10 @@ class NamedTupleVariable(TupleVariable):
else:
return None
if name == "_fields":
source = NamedTupleFieldsSource(self.source) if self.source else None
return VariableTracker.build(tx, self.fields(), source=source)
if name in self.dynamic_attributes:
return self.dynamic_attributes[name]

View File

@ -2095,7 +2095,7 @@ class UserDefinedTupleVariable(UserDefinedObjectVariable):
from torch._dynamo.symbolic_convert import InstructionTranslator
tx = InstructionTranslator.current_tx()
elems = init_args[0].unpack_var_sequence(tx)
elems = init_args[0].force_unpack_var_sequence(tx)
self._tuple_vt = variables.TupleVariable(
elems, mutation_type=ValueMutationNew()
)