eager: Always run dataset iterator operations on CPU.

It has no kernels for other devices.
With an explicit "tf.device()" before invoking the kernel we ensure
that Iterator.next() functions even when placed inside a:

with tf.device("/device:GPU:0")

PiperOrigin-RevId: 171048558
This commit is contained in:
Asim Shankar 2017-10-04 12:48:27 -07:00 committed by TensorFlower Gardener
parent 3b354016e9
commit 491584ff4d

View File

@ -23,6 +23,7 @@ import threading
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
@ -62,6 +63,7 @@ class Iterator(object):
raise RuntimeError(
"{} objects only make sense when eager execution is enabled".format(
type(self)))
with ops.device("/device:CPU:0"):
ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
self._output_types = dataset.output_types
self._flat_output_types = nest.flatten(dataset.output_types)
@ -75,6 +77,7 @@ class Iterator(object):
def __del__(self):
if self._resource is not None:
with ops.device("/device:CPU:0"):
resource_variable_ops.destroy_resource_op(self._resource)
self._resource = None
@ -87,6 +90,10 @@ class Iterator(object):
def next(self):
"""Return the next tf.Tensor from the dataset."""
try:
# TODO(ashankar): Consider removing this ops.device() contextmanager
# and instead mimic ops placement in graphs: Operations on resource
# handles execute on the same device as where the resource is placed.
with ops.device("/device:CPU:0"):
ret = gen_dataset_ops.iterator_get_next(
self._resource,
output_types=self._flat_output_types,