mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
[tf.contrib.data] Fix nested dictionary handling in dataset elements.
Backports recent changes to the core version of the nest.py library. Fixes #12372. PiperOrigin-RevId: 165746517
This commit is contained in:
parent
378463ae89
commit
64e54423bb
|
|
@ -231,6 +231,16 @@ class DatasetConstructorTest(test.TestCase):
|
||||||
dtypes.int64), dataset.output_types)
|
dtypes.int64), dataset.output_types)
|
||||||
self.assertEquals(([], ([], []), []), dataset.output_shapes)
|
self.assertEquals(([], ([], []), []), dataset.output_shapes)
|
||||||
|
|
||||||
|
def testNestedDict(self):
|
||||||
|
components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors(components)
|
||||||
|
self.assertEquals(dtypes.int32, dataset.output_types["a"]["aa"])
|
||||||
|
self.assertEquals(dtypes.float32, dataset.output_types["a"]["ab"])
|
||||||
|
self.assertEquals(dtypes.int32, dataset.output_types["b"])
|
||||||
|
self.assertEquals([], dataset.output_shapes["a"]["aa"])
|
||||||
|
self.assertEquals([2], dataset.output_shapes["a"]["ab"])
|
||||||
|
self.assertEquals([3], dataset.output_shapes["b"])
|
||||||
|
|
||||||
def testNonSequenceNestedStructure(self):
|
def testNonSequenceNestedStructure(self):
|
||||||
components = np.array([1, 2, 3])
|
components = np.array([1, 2, 3])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,14 @@ import six as _six
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
|
|
||||||
|
def _sorted(dict_):
|
||||||
|
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
|
||||||
|
try:
|
||||||
|
return sorted(_six.iterkeys(dict_))
|
||||||
|
except TypeError:
|
||||||
|
raise TypeError("nest only supports dicts with sortable keys.")
|
||||||
|
|
||||||
|
|
||||||
def _sequence_like(instance, args):
|
def _sequence_like(instance, args):
|
||||||
"""Converts the sequence `args` to the same type as `instance`.
|
"""Converts the sequence `args` to the same type as `instance`.
|
||||||
|
|
||||||
|
|
@ -51,9 +59,13 @@ def _sequence_like(instance, args):
|
||||||
`args` with the type of `instance`.
|
`args` with the type of `instance`.
|
||||||
"""
|
"""
|
||||||
if isinstance(instance, dict):
|
if isinstance(instance, dict):
|
||||||
# This is a dict. Iterate over the keys in sorted order to make
|
# Pack dictionaries in a deterministic order by sorting the keys.
|
||||||
# this deterministic.
|
# Notice this means that we ignore the original order of `OrderedDict`
|
||||||
return {k: v for k, v in zip(sorted(instance.keys()), args)}
|
# instances. This is intentional, to avoid potential bugs caused by mixing
|
||||||
|
# ordered and plain dicts (e.g., flattening a dict but using a
|
||||||
|
# corresponding `OrderedDict` to pack it back).
|
||||||
|
result = dict(zip(_sorted(instance), args))
|
||||||
|
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
|
||||||
elif (isinstance(instance, tuple) and
|
elif (isinstance(instance, tuple) and
|
||||||
hasattr(instance, "_fields") and
|
hasattr(instance, "_fields") and
|
||||||
isinstance(instance._fields, _collections.Sequence) and
|
isinstance(instance._fields, _collections.Sequence) and
|
||||||
|
|
@ -65,16 +77,22 @@ def _sequence_like(instance, args):
|
||||||
return type(instance)(args)
|
return type(instance)(args)
|
||||||
|
|
||||||
|
|
||||||
def _elements_of(nest):
|
def _yield_value(iterable):
|
||||||
if isinstance(nest, dict):
|
if isinstance(iterable, dict):
|
||||||
# Iterate over dict keys in sorted order to make this deterministic.
|
# Iterate through dictionaries in a deterministic order by sorting the
|
||||||
return [v for _, v in sorted(nest.items())]
|
# keys. Notice this means that we ignore the original order of `OrderedDict`
|
||||||
|
# instances. This is intentional, to avoid potential bugs caused by mixing
|
||||||
|
# ordered and plain dicts (e.g., flattening a dict but using a
|
||||||
|
# corresponding `OrderedDict` to pack it back).
|
||||||
|
for key in _sorted(iterable):
|
||||||
|
yield iterable[key]
|
||||||
else:
|
else:
|
||||||
return nest
|
for value in iterable:
|
||||||
|
yield value
|
||||||
|
|
||||||
|
|
||||||
def _yield_flat_nest(nest):
|
def _yield_flat_nest(nest):
|
||||||
for n in _elements_of(nest):
|
for n in _yield_value(nest):
|
||||||
if is_sequence(n):
|
if is_sequence(n):
|
||||||
for ni in _yield_flat_nest(n):
|
for ni in _yield_flat_nest(n):
|
||||||
yield ni
|
yield ni
|
||||||
|
|
@ -132,7 +150,7 @@ def _recursive_assert_same_structure(nest1, nest2, check_types):
|
||||||
"structure has type %s, while second structure has type %s."
|
"structure has type %s, while second structure has type %s."
|
||||||
% (type_nest1, type_nest2))
|
% (type_nest1, type_nest2))
|
||||||
|
|
||||||
for n1, n2 in zip(_elements_of(nest1), _elements_of(nest2)):
|
for n1, n2 in zip(_yield_value(nest1), _yield_value(nest2)):
|
||||||
_recursive_assert_same_structure(n1, n2, check_types)
|
_recursive_assert_same_structure(n1, n2, check_types)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -181,7 +199,7 @@ def _packed_nest_with_indices(structure, flat, index):
|
||||||
(assuming indexing starts from `index`).
|
(assuming indexing starts from `index`).
|
||||||
"""
|
"""
|
||||||
packed = []
|
packed = []
|
||||||
for s in structure:
|
for s in _yield_value(structure):
|
||||||
if is_sequence(s):
|
if is_sequence(s):
|
||||||
new_index, child = _packed_nest_with_indices(s, flat, index)
|
new_index, child = _packed_nest_with_indices(s, flat, index)
|
||||||
packed.append(_sequence_like(s, child))
|
packed.append(_sequence_like(s, child))
|
||||||
|
|
@ -286,8 +304,8 @@ def map_structure(func, *structure, **check_types_dict):
|
||||||
def _yield_flat_up_to(shallow_tree, input_tree):
|
def _yield_flat_up_to(shallow_tree, input_tree):
|
||||||
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
|
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
|
||||||
if is_sequence(shallow_tree):
|
if is_sequence(shallow_tree):
|
||||||
for shallow_branch, input_branch in zip(_elements_of(shallow_tree),
|
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
|
||||||
_elements_of(input_tree)):
|
_yield_value(input_tree)):
|
||||||
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
|
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
|
||||||
yield input_leaf
|
yield input_leaf
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,73 @@ class NestTest(test.TestCase):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
|
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
|
||||||
|
|
||||||
|
def testFlattenDictOrder(self):
|
||||||
|
"""`flatten` orders dicts by key, including OrderedDicts."""
|
||||||
|
ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
|
||||||
|
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
|
||||||
|
ordered_flat = nest.flatten(ordered)
|
||||||
|
plain_flat = nest.flatten(plain)
|
||||||
|
self.assertEqual([0, 1, 2, 3], ordered_flat)
|
||||||
|
self.assertEqual([0, 1, 2, 3], plain_flat)
|
||||||
|
|
||||||
|
def testPackDictOrder(self):
|
||||||
|
"""Packing orders dicts by key, including OrderedDicts."""
|
||||||
|
ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
|
||||||
|
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
|
||||||
|
seq = [0, 1, 2, 3]
|
||||||
|
ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
|
||||||
|
plain_reconstruction = nest.pack_sequence_as(plain, seq)
|
||||||
|
self.assertEqual(
|
||||||
|
collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
|
||||||
|
ordered_reconstruction)
|
||||||
|
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
|
||||||
|
|
||||||
|
def testFlattenAndPack_withDicts(self):
|
||||||
|
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
|
||||||
|
named_tuple = collections.namedtuple("A", ("b", "c"))
|
||||||
|
mess = (
|
||||||
|
"z",
|
||||||
|
named_tuple(3, 4),
|
||||||
|
{
|
||||||
|
"c": (
|
||||||
|
1,
|
||||||
|
collections.OrderedDict([
|
||||||
|
("b", 3),
|
||||||
|
("a", 2),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
"b": 5
|
||||||
|
},
|
||||||
|
17
|
||||||
|
)
|
||||||
|
|
||||||
|
flattened = nest.flatten(mess)
|
||||||
|
self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])
|
||||||
|
|
||||||
|
structure_of_mess = (
|
||||||
|
14,
|
||||||
|
named_tuple("a", True),
|
||||||
|
{
|
||||||
|
"c": (
|
||||||
|
0,
|
||||||
|
collections.OrderedDict([
|
||||||
|
("b", 9),
|
||||||
|
("a", 8),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
"b": 3
|
||||||
|
},
|
||||||
|
"hi everybody",
|
||||||
|
)
|
||||||
|
|
||||||
|
unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
|
||||||
|
self.assertEqual(unflattened, mess)
|
||||||
|
|
||||||
|
# Check also that the OrderedDict was created, with the correct key order.
|
||||||
|
unflattened_ordered_dict = unflattened[2]["c"][1]
|
||||||
|
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
|
||||||
|
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
|
||||||
|
|
||||||
def testIsSequence(self):
|
def testIsSequence(self):
|
||||||
self.assertFalse(nest.is_sequence("1234"))
|
self.assertFalse(nest.is_sequence("1234"))
|
||||||
self.assertFalse(nest.is_sequence([1, 3, [4, 5]]))
|
self.assertFalse(nest.is_sequence([1, 3, [4, 5]]))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user