[pytree] reuse flatten_fn in flatten_with_keys_fn to ensure consistency (#117656)

Reuse `flatten_fn` in `flatten_with_keys_fn` to ensure `flatten_fn` and `flatten_with_keys_fn` get the same `leaves` and `context`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117656
Approved by: https://github.com/suo
This commit is contained in:
Xuehai Pan 2024-01-17 20:38:45 +00:00 committed by PyTorch MergeBot
parent bffc8ecfb0
commit c0940d2e93
3 changed files with 32 additions and 22 deletions

View File

@ -54,10 +54,10 @@ CONSTANT_NUMEL_LIMIT = 1
# This could plausibly be handled at the Dynamo level.
pytree.register_pytree_node(
torch.Size,
lambda x: (list(x), None),
lambda xs: (list(xs), None),
lambda xs, _: tuple(xs),
flatten_with_keys_fn=lambda x: (
list(zip(tuple(pytree.SequenceKey(i) for i in range(len(x))), x)),
flatten_with_keys_fn=lambda xs: (
[(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
None,
),
)

View File

@ -13,15 +13,18 @@ def pytree_register_structseq(cls):
return list(structseq), None
def structseq_flatten_with_keys(structseq):
return list(zip(tuple(SequenceKey(i) for i in range(len(structseq))), structseq)), None
values, context = structseq_flatten(structseq)
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
def structseq_unflatten(values, context):
return cls(values)
register_pytree_node(cls,
structseq_flatten,
structseq_unflatten,
flatten_with_keys_fn=structseq_flatten_with_keys)
register_pytree_node(
cls,
structseq_flatten,
structseq_unflatten,
flatten_with_keys_fn=structseq_flatten_with_keys,
)
for name in dir(return_types):
if name.startswith('__'):

View File

@ -344,7 +344,8 @@ def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
def _tuple_flatten_with_keys(
d: Tuple[Any, ...]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
return list(zip(tuple(SequenceKey(i) for i in range(len(d))), d)), None
values, context = _tuple_flatten(d)
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]:
@ -356,7 +357,8 @@ def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
def _list_flatten_with_keys(d: List[Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
return list(zip(tuple(SequenceKey(i) for i in range(len(d))), d)), None
values, context = _list_flatten(d)
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]:
@ -370,7 +372,8 @@ def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
def _dict_flatten_with_keys(
d: Dict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
return list(zip(tuple(MappingKey(k) for k in d.keys()), d.values())), d.keys()
values, context = _dict_flatten(d)
return [(MappingKey(k), v) for k, v in zip(context, values)], context
def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]:
@ -384,7 +387,11 @@ def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
def _namedtuple_flatten_with_keys(
d: NamedTuple,
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
return list(zip(tuple(GetAttrKey(k) for k in d._fields), d)), type(d)
values, context = _namedtuple_flatten(d)
return (
[(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
context,
)
def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple:
@ -413,7 +420,8 @@ def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Co
def _ordereddict_flatten_with_keys(
d: GenericOrderedDict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
return list(zip(tuple(MappingKey(k) for k in d.keys()), d.values())), d.keys()
values, context = _ordereddict_flatten(d)
return [(MappingKey(k), v) for k, v in zip(context, values)], context
def _ordereddict_unflatten(
@ -435,11 +443,9 @@ def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]:
def _defaultdict_flatten_with_keys(
d: DefaultDict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
values, dict_context = _dict_flatten(d)
return list(zip(tuple(MappingKey(k) for k in d.keys()), values)), [
d.default_factory,
dict_context,
]
values, context = _defaultdict_flatten(d)
_, dict_context = context
return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context
def _defaultdict_unflatten(
@ -479,14 +485,15 @@ def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
return [default_factory, dict_context]
def _deque_flatten(deq: Deque[Any]) -> Tuple[List[Any], Context]:
return list(deq), deq.maxlen
def _deque_flatten(d: Deque[Any]) -> Tuple[List[Any], Context]:
return list(d), d.maxlen
def _deque_flatten_with_keys(
deq: Deque[Any],
d: Deque[Any],
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
return list(zip(tuple(SequenceKey(i) for i in range(len(deq))), deq)), deq.maxlen
values, context = _deque_flatten(d)
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]: