mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[tf.data] Apply options for a dataset before it is copied to another device in tf.data.experimental.copy_to_device().
PiperOrigin-RevId: 366293759 Change-Id: I2f38afd28f448cbd5a731a4a0261d74b3cc4ad0c
This commit is contained in:
parent
fb37439d64
commit
aa08128321
|
|
@ -116,6 +116,7 @@ cuda_py_test(
|
|||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/compat",
|
||||
"//tensorflow/python/data/experimental/ops:prefetching_ops",
|
||||
"//tensorflow/python/data/experimental/ops:testing",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from absl.testing import parameterized
|
|||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.data.experimental.ops import prefetching_ops
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
|
|
@ -64,6 +65,24 @@ class CopyToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceHostOptimizations(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
host_dataset = host_dataset.apply(testing.assert_next(["MapAndBatch"]))
|
||||
host_dataset = host_dataset.map(lambda x: x*x).batch(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.copy_to_device("/cpu:1"))
|
||||
|
||||
with ops.device("/cpu:1"):
|
||||
iterator = dataset_ops.make_one_shot_iterator(device_dataset)
|
||||
next_element = iterator.get_next()
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config):
|
||||
self.assertAllEqual([x*x for x in range(10)], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceInt32(self):
|
||||
host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
|||
target_device: The name of the device to which elements would be copied.
|
||||
source_device: Device where input_dataset would be placed.
|
||||
"""
|
||||
self._input_dataset = input_dataset
|
||||
self._input_dataset = input_dataset._apply_options() # pylint: disable=protected-access
|
||||
self._target_device = target_device
|
||||
spec = framework_device.DeviceSpec().from_string(self._target_device)
|
||||
self._is_gpu_target = (spec.device_type == "GPU")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user