mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[tf.contrib.data] Fix handling of multi-output tf.py_func() in Dataset.map().
If the `map_func` returns a list of tensors, the current code will attempt to stack it into a single tensor and raise an unintuitive error. Some multi-output ops (such as `tf.py_func()`) return lists of typically-not-stackable tensors. This change treats lists returned from `map_func` as tuples; users who were relying on this auto-stacking behavior should manually call `tf.stack()` (or `tf.convert_to_tensor()`) on the list being returned. Fixes #12396. PiperOrigin-RevId: 165731970
This commit is contained in:
parent
e6c60fb368
commit
d001b58de9
|
|
@ -549,6 +549,41 @@ class MapDatasetTest(test.TestCase):
|
|||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testReturnList(self):
|
||||
iterator = (dataset_ops.Dataset.range(10)
|
||||
.map(lambda x: [x, constant_op.constant(37.0)])
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual((i, 37.0), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testMultiOutputPyFunc(self):
|
||||
# The `tf.py_func()` op returns a list of tensors for its outputs.
|
||||
def _map_fn(x_tensor):
|
||||
def _map_py_func(x):
|
||||
return x, np.array(37.0, dtype=np.float64)
|
||||
return script_ops.py_func(
|
||||
_map_py_func, [x_tensor], [dtypes.int64, dtypes.float64])
|
||||
|
||||
iterator = (dataset_ops.Dataset.range(10)
|
||||
.map(_map_fn)
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual((i, 37.0), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
|||
|
|
@ -1991,6 +1991,19 @@ class MapDataset(Dataset):
|
|||
else:
|
||||
ret = map_func(nested_args)
|
||||
|
||||
# If `map_func` returns a list of tensors, `nest.flatten()` and
|
||||
# `ops.convert_to_tensor()` would conspire to attempt to stack
|
||||
# those tensors into a single tensor, because the customized
|
||||
# version of `nest.flatten()` does not recurse into lists. Since
|
||||
# it is more likely that the list arose from returning the
|
||||
# result of an operation (such as `tf.py_func()`) that returns a
|
||||
# list of not-necessarily-stackable tensors, we treat the
|
||||
# returned value is a `tuple` instead. A user wishing to pack
|
||||
# the return value into a single tensor can use an explicit
|
||||
# `tf.stack()` before returning.
|
||||
if isinstance(ret, list):
|
||||
ret = tuple(ret)
|
||||
|
||||
# Extract shape information from the returned values.
|
||||
flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
|
||||
self._output_shapes = nest.pack_sequence_as(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user