mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
[tf.contrib.data] Fix handling of multi-output tf.py_func() in Dataset.map().
If the `map_func` returns a list of tensors, the current code will attempt to stack it into a single tensor and raise an unintuitive error. Some multi-output ops (such as `tf.py_func()`) return lists of typically-not-stackable tensors. This change treats lists returned from `map_func` as tuples; users who were relying on this auto-stacking behavior should manually call `tf.stack()` (or `tf.convert_to_tensor()`) on the list being returned. Fixes #12396. PiperOrigin-RevId: 165731970
This commit is contained in:
parent
e6c60fb368
commit
d001b58de9
|
|
@ -549,6 +549,41 @@ class MapDatasetTest(test.TestCase):
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
def testReturnList(self):
|
||||||
|
iterator = (dataset_ops.Dataset.range(10)
|
||||||
|
.map(lambda x: [x, constant_op.constant(37.0)])
|
||||||
|
.make_initializable_iterator())
|
||||||
|
init_op = iterator.initializer
|
||||||
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(init_op)
|
||||||
|
for i in range(10):
|
||||||
|
self.assertEqual((i, 37.0), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
def testMultiOutputPyFunc(self):
|
||||||
|
# The `tf.py_func()` op returns a list of tensors for its outputs.
|
||||||
|
def _map_fn(x_tensor):
|
||||||
|
def _map_py_func(x):
|
||||||
|
return x, np.array(37.0, dtype=np.float64)
|
||||||
|
return script_ops.py_func(
|
||||||
|
_map_py_func, [x_tensor], [dtypes.int64, dtypes.float64])
|
||||||
|
|
||||||
|
iterator = (dataset_ops.Dataset.range(10)
|
||||||
|
.map(_map_fn)
|
||||||
|
.make_initializable_iterator())
|
||||||
|
init_op = iterator.initializer
|
||||||
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(init_op)
|
||||||
|
for i in range(10):
|
||||||
|
self.assertEqual((i, 37.0), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
||||||
|
|
@ -1991,6 +1991,19 @@ class MapDataset(Dataset):
|
||||||
else:
|
else:
|
||||||
ret = map_func(nested_args)
|
ret = map_func(nested_args)
|
||||||
|
|
||||||
|
# If `map_func` returns a list of tensors, `nest.flatten()` and
|
||||||
|
# `ops.convert_to_tensor()` would conspire to attempt to stack
|
||||||
|
# those tensors into a single tensor, because the customized
|
||||||
|
# version of `nest.flatten()` does not recurse into lists. Since
|
||||||
|
# it is more likely that the list arose from returning the
|
||||||
|
# result of an operation (such as `tf.py_func()`) that returns a
|
||||||
|
# list of not-necessarily-stackable tensors, we treat the
|
||||||
|
# returned value is a `tuple` instead. A user wishing to pack
|
||||||
|
# the return value into a single tensor can use an explicit
|
||||||
|
# `tf.stack()` before returning.
|
||||||
|
if isinstance(ret, list):
|
||||||
|
ret = tuple(ret)
|
||||||
|
|
||||||
# Extract shape information from the returned values.
|
# Extract shape information from the returned values.
|
||||||
flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
|
flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
|
||||||
self._output_shapes = nest.pack_sequence_as(
|
self._output_shapes = nest.pack_sequence_as(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user