mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytree] reuse flatten_fn in flatten_with_keys_fn to ensure consistency (#117656)
Reuse `flatten_fn` in `flatten_with_keys_fn` to ensure `flatten_fn` and `flatten_with_keys_fn` get the same `leaves` and `context`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117656 Approved by: https://github.com/suo
This commit is contained in:
parent
bffc8ecfb0
commit
c0940d2e93
|
|
@ -54,10 +54,10 @@ CONSTANT_NUMEL_LIMIT = 1
|
|||
# This could plausibly be handled at the Dynamo level.
|
||||
pytree.register_pytree_node(
|
||||
torch.Size,
|
||||
lambda x: (list(x), None),
|
||||
lambda xs: (list(xs), None),
|
||||
lambda xs, _: tuple(xs),
|
||||
flatten_with_keys_fn=lambda x: (
|
||||
list(zip(tuple(pytree.SequenceKey(i) for i in range(len(x))), x)),
|
||||
flatten_with_keys_fn=lambda xs: (
|
||||
[(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
|
||||
None,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,15 +13,18 @@ def pytree_register_structseq(cls):
|
|||
return list(structseq), None
|
||||
|
||||
def structseq_flatten_with_keys(structseq):
|
||||
return list(zip(tuple(SequenceKey(i) for i in range(len(structseq))), structseq)), None
|
||||
values, context = structseq_flatten(structseq)
|
||||
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
|
||||
|
||||
def structseq_unflatten(values, context):
|
||||
return cls(values)
|
||||
|
||||
register_pytree_node(cls,
|
||||
structseq_flatten,
|
||||
structseq_unflatten,
|
||||
flatten_with_keys_fn=structseq_flatten_with_keys)
|
||||
register_pytree_node(
|
||||
cls,
|
||||
structseq_flatten,
|
||||
structseq_unflatten,
|
||||
flatten_with_keys_fn=structseq_flatten_with_keys,
|
||||
)
|
||||
|
||||
for name in dir(return_types):
|
||||
if name.startswith('__'):
|
||||
|
|
|
|||
|
|
@ -344,7 +344,8 @@ def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
|
|||
def _tuple_flatten_with_keys(
|
||||
d: Tuple[Any, ...]
|
||||
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
|
||||
return list(zip(tuple(SequenceKey(i) for i in range(len(d))), d)), None
|
||||
values, context = _tuple_flatten(d)
|
||||
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
|
||||
|
||||
|
||||
def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]:
|
||||
|
|
@ -356,7 +357,8 @@ def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
|
|||
|
||||
|
||||
def _list_flatten_with_keys(d: List[Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
|
||||
return list(zip(tuple(SequenceKey(i) for i in range(len(d))), d)), None
|
||||
values, context = _list_flatten(d)
|
||||
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
|
||||
|
||||
|
||||
def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]:
|
||||
|
|
@ -370,7 +372,8 @@ def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
|
|||
def _dict_flatten_with_keys(
|
||||
d: Dict[Any, Any]
|
||||
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
|
||||
return list(zip(tuple(MappingKey(k) for k in d.keys()), d.values())), d.keys()
|
||||
values, context = _dict_flatten(d)
|
||||
return [(MappingKey(k), v) for k, v in zip(context, values)], context
|
||||
|
||||
|
||||
def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]:
|
||||
|
|
@ -384,7 +387,11 @@ def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
|
|||
def _namedtuple_flatten_with_keys(
|
||||
d: NamedTuple,
|
||||
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
|
||||
return list(zip(tuple(GetAttrKey(k) for k in d._fields), d)), type(d)
|
||||
values, context = _namedtuple_flatten(d)
|
||||
return (
|
||||
[(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
|
||||
context,
|
||||
)
|
||||
|
||||
|
||||
def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple:
|
||||
|
|
@ -413,7 +420,8 @@ def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Co
|
|||
def _ordereddict_flatten_with_keys(
|
||||
d: GenericOrderedDict[Any, Any]
|
||||
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
|
||||
return list(zip(tuple(MappingKey(k) for k in d.keys()), d.values())), d.keys()
|
||||
values, context = _ordereddict_flatten(d)
|
||||
return [(MappingKey(k), v) for k, v in zip(context, values)], context
|
||||
|
||||
|
||||
def _ordereddict_unflatten(
|
||||
|
|
@ -435,11 +443,9 @@ def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]:
|
|||
def _defaultdict_flatten_with_keys(
|
||||
d: DefaultDict[Any, Any]
|
||||
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
|
||||
values, dict_context = _dict_flatten(d)
|
||||
return list(zip(tuple(MappingKey(k) for k in d.keys()), values)), [
|
||||
d.default_factory,
|
||||
dict_context,
|
||||
]
|
||||
values, context = _defaultdict_flatten(d)
|
||||
_, dict_context = context
|
||||
return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context
|
||||
|
||||
|
||||
def _defaultdict_unflatten(
|
||||
|
|
@ -479,14 +485,15 @@ def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
|
|||
return [default_factory, dict_context]
|
||||
|
||||
|
||||
def _deque_flatten(deq: Deque[Any]) -> Tuple[List[Any], Context]:
|
||||
return list(deq), deq.maxlen
|
||||
def _deque_flatten(d: Deque[Any]) -> Tuple[List[Any], Context]:
|
||||
return list(d), d.maxlen
|
||||
|
||||
|
||||
def _deque_flatten_with_keys(
|
||||
deq: Deque[Any],
|
||||
d: Deque[Any],
|
||||
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
|
||||
return list(zip(tuple(SequenceKey(i) for i in range(len(deq))), deq)), deq.maxlen
|
||||
values, context = _deque_flatten(d)
|
||||
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
|
||||
|
||||
|
||||
def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user