[tf.contrib.data] Fix nested dictionary handling in dataset elements.

Backports recent changes to the core version of the nest.py library.

Fixes #12372.

PiperOrigin-RevId: 165746517
This commit is contained in:
Derek Murray 2017-08-18 13:41:48 -07:00 committed by TensorFlower Gardener
parent 378463ae89
commit 64e54423bb
3 changed files with 108 additions and 13 deletions

View File

@ -231,6 +231,16 @@ class DatasetConstructorTest(test.TestCase):
dtypes.int64), dataset.output_types)
self.assertEquals(([], ([], []), []), dataset.output_shapes)
def testNestedDict(self):
components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
dataset = dataset_ops.Dataset.from_tensors(components)
self.assertEquals(dtypes.int32, dataset.output_types["a"]["aa"])
self.assertEquals(dtypes.float32, dataset.output_types["a"]["ab"])
self.assertEquals(dtypes.int32, dataset.output_types["b"])
self.assertEquals([], dataset.output_shapes["a"]["aa"])
self.assertEquals([2], dataset.output_shapes["a"]["ab"])
self.assertEquals([3], dataset.output_shapes["b"])
def testNonSequenceNestedStructure(self):
components = np.array([1, 2, 3])

View File

@ -40,6 +40,14 @@ import six as _six
from tensorflow.python.util.all_util import remove_undocumented
def _sorted(dict_):
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
try:
return sorted(_six.iterkeys(dict_))
except TypeError:
raise TypeError("nest only supports dicts with sortable keys.")
def _sequence_like(instance, args):
"""Converts the sequence `args` to the same type as `instance`.
@ -51,9 +59,13 @@ def _sequence_like(instance, args):
`args` with the type of `instance`.
"""
if isinstance(instance, dict):
# This is a dict. Iterate over the keys in sorted order to make
# this deterministic.
return {k: v for k, v in zip(sorted(instance.keys()), args)}
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
# ordered and plain dicts (e.g., flattening a dict but using a
# corresponding `OrderedDict` to pack it back).
result = dict(zip(_sorted(instance), args))
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
elif (isinstance(instance, tuple) and
hasattr(instance, "_fields") and
isinstance(instance._fields, _collections.Sequence) and
@ -65,16 +77,22 @@ def _sequence_like(instance, args):
return type(instance)(args)
def _elements_of(nest):
if isinstance(nest, dict):
# Iterate over dict keys in sorted order to make this deterministic.
return [v for _, v in sorted(nest.items())]
def _yield_value(iterable):
if isinstance(iterable, dict):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
# ordered and plain dicts (e.g., flattening a dict but using a
# corresponding `OrderedDict` to pack it back).
for key in _sorted(iterable):
yield iterable[key]
else:
return nest
for value in iterable:
yield value
def _yield_flat_nest(nest):
for n in _elements_of(nest):
for n in _yield_value(nest):
if is_sequence(n):
for ni in _yield_flat_nest(n):
yield ni
@ -132,7 +150,7 @@ def _recursive_assert_same_structure(nest1, nest2, check_types):
"structure has type %s, while second structure has type %s."
% (type_nest1, type_nest2))
for n1, n2 in zip(_elements_of(nest1), _elements_of(nest2)):
for n1, n2 in zip(_yield_value(nest1), _yield_value(nest2)):
_recursive_assert_same_structure(n1, n2, check_types)
@ -181,7 +199,7 @@ def _packed_nest_with_indices(structure, flat, index):
(assuming indexing starts from `index`).
"""
packed = []
for s in structure:
for s in _yield_value(structure):
if is_sequence(s):
new_index, child = _packed_nest_with_indices(s, flat, index)
packed.append(_sequence_like(s, child))
@ -286,8 +304,8 @@ def map_structure(func, *structure, **check_types_dict):
def _yield_flat_up_to(shallow_tree, input_tree):
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
if is_sequence(shallow_tree):
for shallow_branch, input_branch in zip(_elements_of(shallow_tree),
_elements_of(input_tree)):
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
_yield_value(input_tree)):
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
yield input_leaf
else:

View File

@ -65,6 +65,73 @@ class NestTest(test.TestCase):
with self.assertRaises(ValueError):
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
def testFlattenDictOrder(self):
"""`flatten` orders dicts by key, including OrderedDicts."""
ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
ordered_flat = nest.flatten(ordered)
plain_flat = nest.flatten(plain)
self.assertEqual([0, 1, 2, 3], ordered_flat)
self.assertEqual([0, 1, 2, 3], plain_flat)
def testPackDictOrder(self):
"""Packing orders dicts by key, including OrderedDicts."""
ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
seq = [0, 1, 2, 3]
ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
plain_reconstruction = nest.pack_sequence_as(plain, seq)
self.assertEqual(
collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
ordered_reconstruction)
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
def testFlattenAndPack_withDicts(self):
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
named_tuple = collections.namedtuple("A", ("b", "c"))
mess = (
"z",
named_tuple(3, 4),
{
"c": (
1,
collections.OrderedDict([
("b", 3),
("a", 2),
]),
),
"b": 5
},
17
)
flattened = nest.flatten(mess)
self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])
structure_of_mess = (
14,
named_tuple("a", True),
{
"c": (
0,
collections.OrderedDict([
("b", 9),
("a", 8),
]),
),
"b": 3
},
"hi everybody",
)
unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
self.assertEqual(unflattened, mess)
# Check also that the OrderedDict was created, with the correct key order.
unflattened_ordered_dict = unflattened[2]["c"][1]
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
def testIsSequence(self):
self.assertFalse(nest.is_sequence("1234"))
self.assertFalse(nest.is_sequence([1, 3, [4, 5]]))