mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
eager: Always run dataset iterator operations on CPU.
It has no kernels for other devices.
With an explicit "tf.device()" before invoking the kernel we ensure
that Iterator.next() functions even when placed inside a:
with tf.device("/device:GPU:0")
PiperOrigin-RevId: 171048558
This commit is contained in:
parent
3b354016e9
commit
491584ff4d
|
|
@ -23,6 +23,7 @@ import threading
|
|||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
|
||||
|
|
@ -62,20 +63,22 @@ class Iterator(object):
|
|||
raise RuntimeError(
|
||||
"{} objects only make sense when eager execution is enabled".format(
|
||||
type(self)))
|
||||
ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
|
||||
self._output_types = dataset.output_types
|
||||
self._flat_output_types = nest.flatten(dataset.output_types)
|
||||
self._flat_output_shapes = nest.flatten(dataset.output_shapes)
|
||||
self._resource = gen_dataset_ops.iterator(
|
||||
container="",
|
||||
shared_name=_iterator_shared_name(),
|
||||
output_types=self._flat_output_types,
|
||||
output_shapes=self._flat_output_shapes)
|
||||
gen_dataset_ops.make_iterator(ds_variant, self._resource)
|
||||
with ops.device("/device:CPU:0"):
|
||||
ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
|
||||
self._output_types = dataset.output_types
|
||||
self._flat_output_types = nest.flatten(dataset.output_types)
|
||||
self._flat_output_shapes = nest.flatten(dataset.output_shapes)
|
||||
self._resource = gen_dataset_ops.iterator(
|
||||
container="",
|
||||
shared_name=_iterator_shared_name(),
|
||||
output_types=self._flat_output_types,
|
||||
output_shapes=self._flat_output_shapes)
|
||||
gen_dataset_ops.make_iterator(ds_variant, self._resource)
|
||||
|
||||
def __del__(self):
|
||||
if self._resource is not None:
|
||||
resource_variable_ops.destroy_resource_op(self._resource)
|
||||
with ops.device("/device:CPU:0"):
|
||||
resource_variable_ops.destroy_resource_op(self._resource)
|
||||
self._resource = None
|
||||
|
||||
def __iter__(self):
|
||||
|
|
@ -87,10 +90,14 @@ class Iterator(object):
|
|||
def next(self):
|
||||
"""Return the next tf.Tensor from the dataset."""
|
||||
try:
|
||||
ret = gen_dataset_ops.iterator_get_next(
|
||||
self._resource,
|
||||
output_types=self._flat_output_types,
|
||||
output_shapes=self._flat_output_shapes)
|
||||
return nest.pack_sequence_as(self._output_types, ret)
|
||||
# TODO(ashankar): Consider removing this ops.device() contextmanager
|
||||
# and instead mimic ops placement in graphs: Operations on resource
|
||||
# handles execute on the same device as where the resource is placed.
|
||||
with ops.device("/device:CPU:0"):
|
||||
ret = gen_dataset_ops.iterator_get_next(
|
||||
self._resource,
|
||||
output_types=self._flat_output_types,
|
||||
output_shapes=self._flat_output_shapes)
|
||||
return nest.pack_sequence_as(self._output_types, ret)
|
||||
except errors.OutOfRangeError:
|
||||
raise StopIteration
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user