mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes for collections.NamedTuple (#159367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159367 Approved by: https://github.com/mlazos ghstack dependencies: #159365, #159366, #159368, #159483, #159902, #159864, #159865
This commit is contained in:
parent
87d6831b2e
commit
c6333f7dae
|
|
@ -562,6 +562,11 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
args = [a, b]
|
||||
return sub(*args)
|
||||
|
||||
@make_test
|
||||
def test_tuple_map(a, b):
|
||||
t = tuple(map(torch.sin, [a, b]))
|
||||
return t[0] + t[1]
|
||||
|
||||
def test_size_tuple_add(self):
|
||||
def fn():
|
||||
size = torch.Size([])
|
||||
|
|
@ -2016,6 +2021,21 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
tmp = mytuple(a, xy=b)
|
||||
return mytuple(tmp.x, tmp[1], tmp.xy + b)
|
||||
|
||||
@make_test
|
||||
def test_namedtuple_replace(a, b):
|
||||
mytuple = collections.namedtuple("mytuple", ["x", "y"])
|
||||
t = mytuple(a, b)
|
||||
t._replace(x=b)
|
||||
return t.x + t.y
|
||||
|
||||
@make_test
|
||||
def test_namedtuple_fields(a, b):
|
||||
mytuple = collections.namedtuple("mytuple", ["x", "y"])
|
||||
if mytuple._fields == ("x", "y"):
|
||||
return a + b
|
||||
else:
|
||||
return a - b
|
||||
|
||||
class MyNamedTuple(NamedTuple):
|
||||
first: torch.Tensor
|
||||
second: torch.Tensor
|
||||
|
|
|
|||
|
|
@ -1705,16 +1705,17 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||
if hasattr(packed, "b"):
|
||||
b = packed.b + 1
|
||||
c = packed[2]
|
||||
return a + b + c
|
||||
d = len(packed._fields)
|
||||
return a + b + c + d
|
||||
|
||||
v1 = torch.Tensor([1])
|
||||
v2 = torch.Tensor([2])
|
||||
v3 = torch.Tensor([3])
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
self.assertEqual(opt_fn(MyTuple(v1, v2, v3))[0], 7)
|
||||
self.assertEqual(opt_fn(MyTuple(v1, v2, v3))[0], 10)
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 3)
|
||||
self.assertEqual(cnts.op_count, 4)
|
||||
|
||||
def test_namedtuple3(self):
|
||||
def fn(x, packed):
|
||||
|
|
|
|||
|
|
@ -139,6 +139,7 @@ from .source import (
|
|||
GradSource,
|
||||
ListGetItemSource,
|
||||
LocalSource,
|
||||
NamedTupleFieldsSource,
|
||||
NNModuleSource,
|
||||
NonSerializableSetGetItemSource,
|
||||
NumpyTensorSource,
|
||||
|
|
@ -727,6 +728,7 @@ def _get_closure_vars() -> dict[str, object]:
|
|||
"___normalize_range_iter": normalize_range_iter,
|
||||
"___tuple_iterator_getitem": tuple_iterator_getitem,
|
||||
"___dataclass_fields": dataclass_fields,
|
||||
"___namedtuple_fields": lambda x: x._fields,
|
||||
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
|
||||
"__math_isnan": math.isnan,
|
||||
"__numpy_isnan": None if np is None else np.isnan,
|
||||
|
|
@ -1680,6 +1682,14 @@ class GuardBuilder(GuardBuilderBase):
|
|||
example_value=example_value,
|
||||
guard_manager_enum=guard_manager_enum,
|
||||
)
|
||||
elif istype(source, NamedTupleFieldsSource):
|
||||
assert base_guard_manager
|
||||
out = base_guard_manager.lambda_manager(
|
||||
python_lambda=lambda x: x._fields,
|
||||
source=source_name,
|
||||
example_value=example_value,
|
||||
guard_manager_enum=guard_manager_enum,
|
||||
)
|
||||
elif istype(source, CodeSource):
|
||||
assert base_guard_manager # to make mypy happy
|
||||
out = base_guard_manager.code_manager(
|
||||
|
|
|
|||
|
|
@ -830,6 +830,19 @@ class TupleIteratorGetItemSource(GetItemSource):
|
|||
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class NamedTupleFieldsSource(ChainedSource):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
codegen.extend_output(codegen.create_load_attrs("_fields"))
|
||||
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self) -> str:
|
||||
return f"___namedtuple_fields({self.base.name()})"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DataclassFieldsSource(ChainedSource):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
|
|
|
|||
|
|
@ -3595,6 +3595,12 @@ class SourcelessBuilder:
|
|||
if trace_rules.is_callable_allowed(value):
|
||||
tx.output.has_user_defined_allowed_in_graph = True
|
||||
return trace_rules.lookup_callable(value)(value)
|
||||
elif callable(value) and UserDefinedClassVariable.is_supported_new_method(
|
||||
value
|
||||
):
|
||||
# NamedTuple._make uses an alias of tuple.__new__
|
||||
obj = trace_rules.lookup_callable(value.__self__)(value.__self__)
|
||||
return GetAttrVariable(obj, "__new__")
|
||||
elif is_function_or_wrapper(value):
|
||||
return trace_rules.lookup(value)(value)
|
||||
elif isinstance(
|
||||
|
|
|
|||
|
|
@ -1373,11 +1373,11 @@ class BuiltinVariable(VariableTracker):
|
|||
if (
|
||||
self.fn is tuple
|
||||
and len(args) == 2
|
||||
and args[1].has_unpack_var_sequence(tx)
|
||||
and args[1].has_force_unpack_var_sequence(tx)
|
||||
and not kwargs
|
||||
):
|
||||
if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple:
|
||||
init_args = args[1].unpack_var_sequence(tx)
|
||||
init_args = args[1].force_unpack_var_sequence(tx)
|
||||
return variables.TupleVariable(
|
||||
init_args, mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
|
@ -2001,10 +2001,7 @@ class BuiltinVariable(VariableTracker):
|
|||
if kwargs:
|
||||
assert len(kwargs) == 1 and "strict" in kwargs
|
||||
strict = kwargs.pop("strict", False)
|
||||
args = [
|
||||
arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg
|
||||
for arg in args
|
||||
]
|
||||
args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args]
|
||||
return variables.ZipVariable(
|
||||
args, strict=strict, mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1850,10 +1850,17 @@ class CollectionsNamedTupleFunction(UserFunctionVariable):
|
|||
) -> "VariableTracker":
|
||||
constant_args = check_constant_args(args, kwargs)
|
||||
if constant_args:
|
||||
value = self.fn(
|
||||
*[x.as_python_constant() for x in args],
|
||||
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||||
)
|
||||
try:
|
||||
value = self.fn(
|
||||
*[x.as_python_constant() for x in args],
|
||||
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||||
)
|
||||
except TypeError as exc:
|
||||
raise_observed_exception(
|
||||
type(exc),
|
||||
tx,
|
||||
args=list(map(ConstantVariable.create, exc.args)),
|
||||
)
|
||||
return variables.UserDefinedClassVariable(
|
||||
value, mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -345,7 +345,7 @@ class ZipVariable(IteratorVariable):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
iterables: list[Union[list[VariableTracker], VariableTracker]],
|
||||
iterables: list[VariableTracker],
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import torch.fx
|
|||
from .. import graph_break_hints, polyfills, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..exc import raise_observed_exception, unimplemented_v2
|
||||
from ..source import AttrSource
|
||||
from ..source import AttrSource, NamedTupleFieldsSource
|
||||
from ..utils import (
|
||||
cmp_name_to_op_mapping,
|
||||
cmp_name_to_op_str_mapping,
|
||||
|
|
@ -1150,6 +1150,10 @@ class NamedTupleVariable(TupleVariable):
|
|||
else:
|
||||
return None
|
||||
|
||||
if name == "_fields":
|
||||
source = NamedTupleFieldsSource(self.source) if self.source else None
|
||||
return VariableTracker.build(tx, self.fields(), source=source)
|
||||
|
||||
if name in self.dynamic_attributes:
|
||||
return self.dynamic_attributes[name]
|
||||
|
||||
|
|
|
|||
|
|
@ -2095,7 +2095,7 @@ class UserDefinedTupleVariable(UserDefinedObjectVariable):
|
|||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
tx = InstructionTranslator.current_tx()
|
||||
elems = init_args[0].unpack_var_sequence(tx)
|
||||
elems = init_args[0].force_unpack_var_sequence(tx)
|
||||
self._tuple_vt = variables.TupleVariable(
|
||||
elems, mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user