mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes for collections.NamedTuple (#159367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159367 Approved by: https://github.com/mlazos ghstack dependencies: #159365, #159366, #159368, #159483, #159902, #159864, #159865
This commit is contained in:
parent
87d6831b2e
commit
c6333f7dae
|
|
@ -562,6 +562,11 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||||
args = [a, b]
|
args = [a, b]
|
||||||
return sub(*args)
|
return sub(*args)
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_tuple_map(a, b):
|
||||||
|
t = tuple(map(torch.sin, [a, b]))
|
||||||
|
return t[0] + t[1]
|
||||||
|
|
||||||
def test_size_tuple_add(self):
|
def test_size_tuple_add(self):
|
||||||
def fn():
|
def fn():
|
||||||
size = torch.Size([])
|
size = torch.Size([])
|
||||||
|
|
@ -2016,6 +2021,21 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||||
tmp = mytuple(a, xy=b)
|
tmp = mytuple(a, xy=b)
|
||||||
return mytuple(tmp.x, tmp[1], tmp.xy + b)
|
return mytuple(tmp.x, tmp[1], tmp.xy + b)
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_namedtuple_replace(a, b):
|
||||||
|
mytuple = collections.namedtuple("mytuple", ["x", "y"])
|
||||||
|
t = mytuple(a, b)
|
||||||
|
t._replace(x=b)
|
||||||
|
return t.x + t.y
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_namedtuple_fields(a, b):
|
||||||
|
mytuple = collections.namedtuple("mytuple", ["x", "y"])
|
||||||
|
if mytuple._fields == ("x", "y"):
|
||||||
|
return a + b
|
||||||
|
else:
|
||||||
|
return a - b
|
||||||
|
|
||||||
class MyNamedTuple(NamedTuple):
|
class MyNamedTuple(NamedTuple):
|
||||||
first: torch.Tensor
|
first: torch.Tensor
|
||||||
second: torch.Tensor
|
second: torch.Tensor
|
||||||
|
|
|
||||||
|
|
@ -1705,16 +1705,17 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
||||||
if hasattr(packed, "b"):
|
if hasattr(packed, "b"):
|
||||||
b = packed.b + 1
|
b = packed.b + 1
|
||||||
c = packed[2]
|
c = packed[2]
|
||||||
return a + b + c
|
d = len(packed._fields)
|
||||||
|
return a + b + c + d
|
||||||
|
|
||||||
v1 = torch.Tensor([1])
|
v1 = torch.Tensor([1])
|
||||||
v2 = torch.Tensor([2])
|
v2 = torch.Tensor([2])
|
||||||
v3 = torch.Tensor([3])
|
v3 = torch.Tensor([3])
|
||||||
cnts = torch._dynamo.testing.CompileCounter()
|
cnts = torch._dynamo.testing.CompileCounter()
|
||||||
opt_fn = torch.compile(fn, backend=cnts)
|
opt_fn = torch.compile(fn, backend=cnts)
|
||||||
self.assertEqual(opt_fn(MyTuple(v1, v2, v3))[0], 7)
|
self.assertEqual(opt_fn(MyTuple(v1, v2, v3))[0], 10)
|
||||||
self.assertEqual(cnts.frame_count, 1)
|
self.assertEqual(cnts.frame_count, 1)
|
||||||
self.assertEqual(cnts.op_count, 3)
|
self.assertEqual(cnts.op_count, 4)
|
||||||
|
|
||||||
def test_namedtuple3(self):
|
def test_namedtuple3(self):
|
||||||
def fn(x, packed):
|
def fn(x, packed):
|
||||||
|
|
|
||||||
|
|
@ -139,6 +139,7 @@ from .source import (
|
||||||
GradSource,
|
GradSource,
|
||||||
ListGetItemSource,
|
ListGetItemSource,
|
||||||
LocalSource,
|
LocalSource,
|
||||||
|
NamedTupleFieldsSource,
|
||||||
NNModuleSource,
|
NNModuleSource,
|
||||||
NonSerializableSetGetItemSource,
|
NonSerializableSetGetItemSource,
|
||||||
NumpyTensorSource,
|
NumpyTensorSource,
|
||||||
|
|
@ -727,6 +728,7 @@ def _get_closure_vars() -> dict[str, object]:
|
||||||
"___normalize_range_iter": normalize_range_iter,
|
"___normalize_range_iter": normalize_range_iter,
|
||||||
"___tuple_iterator_getitem": tuple_iterator_getitem,
|
"___tuple_iterator_getitem": tuple_iterator_getitem,
|
||||||
"___dataclass_fields": dataclass_fields,
|
"___dataclass_fields": dataclass_fields,
|
||||||
|
"___namedtuple_fields": lambda x: x._fields,
|
||||||
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
|
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
|
||||||
"__math_isnan": math.isnan,
|
"__math_isnan": math.isnan,
|
||||||
"__numpy_isnan": None if np is None else np.isnan,
|
"__numpy_isnan": None if np is None else np.isnan,
|
||||||
|
|
@ -1680,6 +1682,14 @@ class GuardBuilder(GuardBuilderBase):
|
||||||
example_value=example_value,
|
example_value=example_value,
|
||||||
guard_manager_enum=guard_manager_enum,
|
guard_manager_enum=guard_manager_enum,
|
||||||
)
|
)
|
||||||
|
elif istype(source, NamedTupleFieldsSource):
|
||||||
|
assert base_guard_manager
|
||||||
|
out = base_guard_manager.lambda_manager(
|
||||||
|
python_lambda=lambda x: x._fields,
|
||||||
|
source=source_name,
|
||||||
|
example_value=example_value,
|
||||||
|
guard_manager_enum=guard_manager_enum,
|
||||||
|
)
|
||||||
elif istype(source, CodeSource):
|
elif istype(source, CodeSource):
|
||||||
assert base_guard_manager # to make mypy happy
|
assert base_guard_manager # to make mypy happy
|
||||||
out = base_guard_manager.code_manager(
|
out = base_guard_manager.code_manager(
|
||||||
|
|
|
||||||
|
|
@ -830,6 +830,19 @@ class TupleIteratorGetItemSource(GetItemSource):
|
||||||
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
|
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class NamedTupleFieldsSource(ChainedSource):
|
||||||
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
|
codegen(self.base)
|
||||||
|
codegen.extend_output(codegen.create_load_attrs("_fields"))
|
||||||
|
|
||||||
|
def guard_source(self) -> GuardSource:
|
||||||
|
return self.base.guard_source()
|
||||||
|
|
||||||
|
def name(self) -> str:
|
||||||
|
return f"___namedtuple_fields({self.base.name()})"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class DataclassFieldsSource(ChainedSource):
|
class DataclassFieldsSource(ChainedSource):
|
||||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
|
|
|
||||||
|
|
@ -3595,6 +3595,12 @@ class SourcelessBuilder:
|
||||||
if trace_rules.is_callable_allowed(value):
|
if trace_rules.is_callable_allowed(value):
|
||||||
tx.output.has_user_defined_allowed_in_graph = True
|
tx.output.has_user_defined_allowed_in_graph = True
|
||||||
return trace_rules.lookup_callable(value)(value)
|
return trace_rules.lookup_callable(value)(value)
|
||||||
|
elif callable(value) and UserDefinedClassVariable.is_supported_new_method(
|
||||||
|
value
|
||||||
|
):
|
||||||
|
# NamedTuple._make uses an alias of tuple.__new__
|
||||||
|
obj = trace_rules.lookup_callable(value.__self__)(value.__self__)
|
||||||
|
return GetAttrVariable(obj, "__new__")
|
||||||
elif is_function_or_wrapper(value):
|
elif is_function_or_wrapper(value):
|
||||||
return trace_rules.lookup(value)(value)
|
return trace_rules.lookup(value)(value)
|
||||||
elif isinstance(
|
elif isinstance(
|
||||||
|
|
|
||||||
|
|
@ -1373,11 +1373,11 @@ class BuiltinVariable(VariableTracker):
|
||||||
if (
|
if (
|
||||||
self.fn is tuple
|
self.fn is tuple
|
||||||
and len(args) == 2
|
and len(args) == 2
|
||||||
and args[1].has_unpack_var_sequence(tx)
|
and args[1].has_force_unpack_var_sequence(tx)
|
||||||
and not kwargs
|
and not kwargs
|
||||||
):
|
):
|
||||||
if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple:
|
if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple:
|
||||||
init_args = args[1].unpack_var_sequence(tx)
|
init_args = args[1].force_unpack_var_sequence(tx)
|
||||||
return variables.TupleVariable(
|
return variables.TupleVariable(
|
||||||
init_args, mutation_type=ValueMutationNew()
|
init_args, mutation_type=ValueMutationNew()
|
||||||
)
|
)
|
||||||
|
|
@ -2001,10 +2001,7 @@ class BuiltinVariable(VariableTracker):
|
||||||
if kwargs:
|
if kwargs:
|
||||||
assert len(kwargs) == 1 and "strict" in kwargs
|
assert len(kwargs) == 1 and "strict" in kwargs
|
||||||
strict = kwargs.pop("strict", False)
|
strict = kwargs.pop("strict", False)
|
||||||
args = [
|
args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args]
|
||||||
arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg
|
|
||||||
for arg in args
|
|
||||||
]
|
|
||||||
return variables.ZipVariable(
|
return variables.ZipVariable(
|
||||||
args, strict=strict, mutation_type=ValueMutationNew()
|
args, strict=strict, mutation_type=ValueMutationNew()
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1850,10 +1850,17 @@ class CollectionsNamedTupleFunction(UserFunctionVariable):
|
||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
constant_args = check_constant_args(args, kwargs)
|
constant_args = check_constant_args(args, kwargs)
|
||||||
if constant_args:
|
if constant_args:
|
||||||
value = self.fn(
|
try:
|
||||||
*[x.as_python_constant() for x in args],
|
value = self.fn(
|
||||||
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
*[x.as_python_constant() for x in args],
|
||||||
)
|
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||||||
|
)
|
||||||
|
except TypeError as exc:
|
||||||
|
raise_observed_exception(
|
||||||
|
type(exc),
|
||||||
|
tx,
|
||||||
|
args=list(map(ConstantVariable.create, exc.args)),
|
||||||
|
)
|
||||||
return variables.UserDefinedClassVariable(
|
return variables.UserDefinedClassVariable(
|
||||||
value, mutation_type=ValueMutationNew()
|
value, mutation_type=ValueMutationNew()
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -345,7 +345,7 @@ class ZipVariable(IteratorVariable):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
iterables: list[Union[list[VariableTracker], VariableTracker]],
|
iterables: list[VariableTracker],
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ import torch.fx
|
||||||
from .. import graph_break_hints, polyfills, variables
|
from .. import graph_break_hints, polyfills, variables
|
||||||
from ..bytecode_transformation import create_call_function, create_instruction
|
from ..bytecode_transformation import create_call_function, create_instruction
|
||||||
from ..exc import raise_observed_exception, unimplemented_v2
|
from ..exc import raise_observed_exception, unimplemented_v2
|
||||||
from ..source import AttrSource
|
from ..source import AttrSource, NamedTupleFieldsSource
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
cmp_name_to_op_mapping,
|
cmp_name_to_op_mapping,
|
||||||
cmp_name_to_op_str_mapping,
|
cmp_name_to_op_str_mapping,
|
||||||
|
|
@ -1150,6 +1150,10 @@ class NamedTupleVariable(TupleVariable):
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if name == "_fields":
|
||||||
|
source = NamedTupleFieldsSource(self.source) if self.source else None
|
||||||
|
return VariableTracker.build(tx, self.fields(), source=source)
|
||||||
|
|
||||||
if name in self.dynamic_attributes:
|
if name in self.dynamic_attributes:
|
||||||
return self.dynamic_attributes[name]
|
return self.dynamic_attributes[name]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2095,7 +2095,7 @@ class UserDefinedTupleVariable(UserDefinedObjectVariable):
|
||||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||||
|
|
||||||
tx = InstructionTranslator.current_tx()
|
tx = InstructionTranslator.current_tx()
|
||||||
elems = init_args[0].unpack_var_sequence(tx)
|
elems = init_args[0].force_unpack_var_sequence(tx)
|
||||||
self._tuple_vt = variables.TupleVariable(
|
self._tuple_vt = variables.TupleVariable(
|
||||||
elems, mutation_type=ValueMutationNew()
|
elems, mutation_type=ValueMutationNew()
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user