PSv2/cfit: Remove the workaround in PS' dataset creation as the underlying issue has been fixed by 04f654efc6

PiperOrigin-RevId: 373014461
Change-Id: Idab643ad1dcd918cf138e8b83450ccf9f0e7b3dd
This commit is contained in:
Rick Chao 2021-05-10 14:36:39 -07:00 committed by TensorFlower Gardener
parent 58bce30a18
commit 1555fec94e

View File

@ -1353,16 +1353,8 @@ class _ClusterCoordinatorDataHandler(DataHandler):
def per_worker_dataset_fn(): def per_worker_dataset_fn():
def wrapped_dataset_fn(input_context):
# TODO(b/186692679): Currently we need to remove the device scope
# imposed in `distribute_datasets_from_function` lib so that any
# `StaticHashTable` is placed on the coordinator. Remove this workaround
# once resolved.
with ops.device_v2(None):
return x(input_context)
return strategy.distribute_datasets_from_function( return strategy.distribute_datasets_from_function(
wrapped_dataset_fn, options=x.input_options) x, options=x.input_options)
self._dataset = self._model._cluster_coordinator.create_per_worker_dataset( # pylint: disable=protected-access self._dataset = self._model._cluster_coordinator.create_per_worker_dataset( # pylint: disable=protected-access
per_worker_dataset_fn) per_worker_dataset_fn)