[tf.contrib.data] Fix handling of multi-output tf.py_func() in Dataset.map().

If the `map_func` returns a list of tensors, the current code will
attempt to stack it into a single tensor and raise an unintuitive
error. Some multi-output ops (such as `tf.py_func()`) return lists of
typically-not-stackable tensors. This change treats lists returned
from `map_func` as tuples; users who were relying on this
auto-stacking behavior should manually call `tf.stack()` (or
`tf.convert_to_tensor()`) on the list being returned.

Fixes #12396.

PiperOrigin-RevId: 165731970
This commit is contained in:
Derek Murray 2017-08-18 11:41:47 -07:00 committed by TensorFlower Gardener
parent e6c60fb368
commit d001b58de9
2 changed files with 48 additions and 0 deletions

View File

@ -549,6 +549,41 @@ class MapDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testReturnList(self):
iterator = (dataset_ops.Dataset.range(10)
.map(lambda x: [x, constant_op.constant(37.0)])
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, 37.0), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testMultiOutputPyFunc(self):
# The `tf.py_func()` op returns a list of tensors for its outputs.
def _map_fn(x_tensor):
def _map_py_func(x):
return x, np.array(37.0, dtype=np.float64)
return script_ops.py_func(
_map_py_func, [x_tensor], [dtypes.int64, dtypes.float64])
iterator = (dataset_ops.Dataset.range(10)
.map(_map_fn)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, 37.0), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
if __name__ == "__main__":
test.main()

View File

@ -1991,6 +1991,19 @@ class MapDataset(Dataset):
else:
ret = map_func(nested_args)
# If `map_func` returns a list of tensors, `nest.flatten()` and
# `ops.convert_to_tensor()` would conspire to attempt to stack
# those tensors into a single tensor, because the customized
# version of `nest.flatten()` does not recurse into lists. Since
# it is more likely that the list arose from returning the
# result of an operation (such as `tf.py_func()`) that returns a
# list of not-necessarily-stackable tensors, we treat the
# returned value is a `tuple` instead. A user wishing to pack
# the return value into a single tensor can use an explicit
# `tf.stack()` before returning.
if isinstance(ret, list):
ret = tuple(ret)
# Extract shape information from the returned values.
flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
self._output_shapes = nest.pack_sequence_as(