[tf.data] Apply options for a dataset before it is copied to another device in tf.data.experimental.copy_to_device().

PiperOrigin-RevId: 366293759
Change-Id: I2f38afd28f448cbd5a731a4a0261d74b3cc4ad0c
This commit is contained in:
Jiri Simsa 2021-04-01 11:28:16 -07:00 committed by Geeta Chavan
parent fb37439d64
commit aa08128321
3 changed files with 21 additions and 1 deletions

View File

@ -116,6 +116,7 @@ cuda_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python/compat",
"//tensorflow/python/data/experimental/ops:prefetching_ops",
"//tensorflow/python/data/experimental/ops:testing",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",

View File

@ -21,6 +21,7 @@ from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
@ -64,6 +65,24 @@ class CopyToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceHostOptimizations(self):
host_dataset = dataset_ops.Dataset.range(10)
host_dataset = host_dataset.apply(testing.assert_next(["MapAndBatch"]))
host_dataset = host_dataset.map(lambda x: x*x).batch(10)
device_dataset = host_dataset.apply(
prefetching_ops.copy_to_device("/cpu:1"))
with ops.device("/cpu:1"):
iterator = dataset_ops.make_one_shot_iterator(device_dataset)
next_element = iterator.get_next()
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config):
self.assertAllEqual([x*x for x in range(10)], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceInt32(self):
host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])

View File

@ -94,7 +94,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
target_device: The name of the device to which elements would be copied.
source_device: Device where input_dataset would be placed.
"""
self._input_dataset = input_dataset
self._input_dataset = input_dataset._apply_options() # pylint: disable=protected-access
self._target_device = target_device
spec = framework_device.DeviceSpec().from_string(self._target_device)
self._is_gpu_target = (spec.device_type == "GPU")