mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[tf.contrib.data] Fix nested dictionary handling in dataset elements.
Backports recent changes to the core version of the nest.py library. Fixes #12372. PiperOrigin-RevId: 165746517
This commit is contained in:
parent
378463ae89
commit
64e54423bb
|
|
@ -231,6 +231,16 @@ class DatasetConstructorTest(test.TestCase):
|
|||
dtypes.int64), dataset.output_types)
|
||||
self.assertEquals(([], ([], []), []), dataset.output_shapes)
|
||||
|
||||
def testNestedDict(self):
|
||||
components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
|
||||
dataset = dataset_ops.Dataset.from_tensors(components)
|
||||
self.assertEquals(dtypes.int32, dataset.output_types["a"]["aa"])
|
||||
self.assertEquals(dtypes.float32, dataset.output_types["a"]["ab"])
|
||||
self.assertEquals(dtypes.int32, dataset.output_types["b"])
|
||||
self.assertEquals([], dataset.output_shapes["a"]["aa"])
|
||||
self.assertEquals([2], dataset.output_shapes["a"]["ab"])
|
||||
self.assertEquals([3], dataset.output_shapes["b"])
|
||||
|
||||
def testNonSequenceNestedStructure(self):
|
||||
components = np.array([1, 2, 3])
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,14 @@ import six as _six
|
|||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
|
||||
def _sorted(dict_):
|
||||
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
|
||||
try:
|
||||
return sorted(_six.iterkeys(dict_))
|
||||
except TypeError:
|
||||
raise TypeError("nest only supports dicts with sortable keys.")
|
||||
|
||||
|
||||
def _sequence_like(instance, args):
|
||||
"""Converts the sequence `args` to the same type as `instance`.
|
||||
|
||||
|
|
@ -51,9 +59,13 @@ def _sequence_like(instance, args):
|
|||
`args` with the type of `instance`.
|
||||
"""
|
||||
if isinstance(instance, dict):
|
||||
# This is a dict. Iterate over the keys in sorted order to make
|
||||
# this deterministic.
|
||||
return {k: v for k, v in zip(sorted(instance.keys()), args)}
|
||||
# Pack dictionaries in a deterministic order by sorting the keys.
|
||||
# Notice this means that we ignore the original order of `OrderedDict`
|
||||
# instances. This is intentional, to avoid potential bugs caused by mixing
|
||||
# ordered and plain dicts (e.g., flattening a dict but using a
|
||||
# corresponding `OrderedDict` to pack it back).
|
||||
result = dict(zip(_sorted(instance), args))
|
||||
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
|
||||
elif (isinstance(instance, tuple) and
|
||||
hasattr(instance, "_fields") and
|
||||
isinstance(instance._fields, _collections.Sequence) and
|
||||
|
|
@ -65,16 +77,22 @@ def _sequence_like(instance, args):
|
|||
return type(instance)(args)
|
||||
|
||||
|
||||
def _elements_of(nest):
|
||||
if isinstance(nest, dict):
|
||||
# Iterate over dict keys in sorted order to make this deterministic.
|
||||
return [v for _, v in sorted(nest.items())]
|
||||
def _yield_value(iterable):
|
||||
if isinstance(iterable, dict):
|
||||
# Iterate through dictionaries in a deterministic order by sorting the
|
||||
# keys. Notice this means that we ignore the original order of `OrderedDict`
|
||||
# instances. This is intentional, to avoid potential bugs caused by mixing
|
||||
# ordered and plain dicts (e.g., flattening a dict but using a
|
||||
# corresponding `OrderedDict` to pack it back).
|
||||
for key in _sorted(iterable):
|
||||
yield iterable[key]
|
||||
else:
|
||||
return nest
|
||||
for value in iterable:
|
||||
yield value
|
||||
|
||||
|
||||
def _yield_flat_nest(nest):
|
||||
for n in _elements_of(nest):
|
||||
for n in _yield_value(nest):
|
||||
if is_sequence(n):
|
||||
for ni in _yield_flat_nest(n):
|
||||
yield ni
|
||||
|
|
@ -132,7 +150,7 @@ def _recursive_assert_same_structure(nest1, nest2, check_types):
|
|||
"structure has type %s, while second structure has type %s."
|
||||
% (type_nest1, type_nest2))
|
||||
|
||||
for n1, n2 in zip(_elements_of(nest1), _elements_of(nest2)):
|
||||
for n1, n2 in zip(_yield_value(nest1), _yield_value(nest2)):
|
||||
_recursive_assert_same_structure(n1, n2, check_types)
|
||||
|
||||
|
||||
|
|
@ -181,7 +199,7 @@ def _packed_nest_with_indices(structure, flat, index):
|
|||
(assuming indexing starts from `index`).
|
||||
"""
|
||||
packed = []
|
||||
for s in structure:
|
||||
for s in _yield_value(structure):
|
||||
if is_sequence(s):
|
||||
new_index, child = _packed_nest_with_indices(s, flat, index)
|
||||
packed.append(_sequence_like(s, child))
|
||||
|
|
@ -286,8 +304,8 @@ def map_structure(func, *structure, **check_types_dict):
|
|||
def _yield_flat_up_to(shallow_tree, input_tree):
|
||||
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
|
||||
if is_sequence(shallow_tree):
|
||||
for shallow_branch, input_branch in zip(_elements_of(shallow_tree),
|
||||
_elements_of(input_tree)):
|
||||
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
|
||||
_yield_value(input_tree)):
|
||||
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
|
||||
yield input_leaf
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -65,6 +65,73 @@ class NestTest(test.TestCase):
|
|||
with self.assertRaises(ValueError):
|
||||
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
|
||||
|
||||
def testFlattenDictOrder(self):
|
||||
"""`flatten` orders dicts by key, including OrderedDicts."""
|
||||
ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
|
||||
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
|
||||
ordered_flat = nest.flatten(ordered)
|
||||
plain_flat = nest.flatten(plain)
|
||||
self.assertEqual([0, 1, 2, 3], ordered_flat)
|
||||
self.assertEqual([0, 1, 2, 3], plain_flat)
|
||||
|
||||
def testPackDictOrder(self):
|
||||
"""Packing orders dicts by key, including OrderedDicts."""
|
||||
ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
|
||||
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
|
||||
seq = [0, 1, 2, 3]
|
||||
ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
|
||||
plain_reconstruction = nest.pack_sequence_as(plain, seq)
|
||||
self.assertEqual(
|
||||
collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
|
||||
ordered_reconstruction)
|
||||
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
|
||||
|
||||
def testFlattenAndPack_withDicts(self):
|
||||
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
|
||||
named_tuple = collections.namedtuple("A", ("b", "c"))
|
||||
mess = (
|
||||
"z",
|
||||
named_tuple(3, 4),
|
||||
{
|
||||
"c": (
|
||||
1,
|
||||
collections.OrderedDict([
|
||||
("b", 3),
|
||||
("a", 2),
|
||||
]),
|
||||
),
|
||||
"b": 5
|
||||
},
|
||||
17
|
||||
)
|
||||
|
||||
flattened = nest.flatten(mess)
|
||||
self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])
|
||||
|
||||
structure_of_mess = (
|
||||
14,
|
||||
named_tuple("a", True),
|
||||
{
|
||||
"c": (
|
||||
0,
|
||||
collections.OrderedDict([
|
||||
("b", 9),
|
||||
("a", 8),
|
||||
]),
|
||||
),
|
||||
"b": 3
|
||||
},
|
||||
"hi everybody",
|
||||
)
|
||||
|
||||
unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
|
||||
self.assertEqual(unflattened, mess)
|
||||
|
||||
# Check also that the OrderedDict was created, with the correct key order.
|
||||
unflattened_ordered_dict = unflattened[2]["c"][1]
|
||||
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
|
||||
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
|
||||
|
||||
def testIsSequence(self):
|
||||
self.assertFalse(nest.is_sequence("1234"))
|
||||
self.assertFalse(nest.is_sequence([1, 3, [4, 5]]))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user