Create experimental API for retrieving worker index during parameter server training.

PiperOrigin-RevId: 513253825
James Mullenbach 2023-03-01 09:11:19 -08:00 committed by TensorFlower Gardener
parent 3177298992
commit a5a1ced2e2
4 changed files with 143 additions and 2 deletions

View File

@@ -62,6 +62,13 @@
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* `tf.distribute`
  * Opened an experimental API,
    `tf.distribute.experimental.coordinator.get_current_worker_index`, for
    retrieving the worker index from within a worker when using parameter
    server training with a custom training loop.

## Thanks to our Contributors

This release contains contributions from many people at Google, as well as:

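As context for the release note above, the sketch below shows one way the new API might be wired into a custom parameter server training loop to shard data by worker. This is a minimal, illustrative sketch, not part of the commit: it assumes a cluster already described by the `TF_CONFIG` environment variable, and `num_workers`, `dataset_fn`, `per_worker_dataset_fn`, and `train_step` are hypothetical names. The docstring added later in this commit carries the canonical example.

```python
import tensorflow as tf

# Assumes TF_CONFIG describes a running parameter server cluster.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=tf.distribute.cluster_resolver.TFConfigClusterResolver())
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)
num_workers = 5  # Illustrative; should match the cluster's worker count.


def dataset_fn(context):
  del context  # Sharding is keyed off the worker index instead.
  dataset = tf.data.Dataset.range(10)
  worker_index = (
      tf.distribute.experimental.coordinator.get_current_worker_index())
  # Each worker reads only its own shard of the data.
  return dataset.shard(num_shards=num_workers, index=worker_index)


@tf.function
def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)


per_worker_dataset = coordinator.create_per_worker_dataset(
    per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)


@tf.function
def train_step(iterator):
  # Stand-in for a real training step in a custom training loop.
  return tf.reduce_sum(next(iterator))


result = coordinator.schedule(train_step, args=(per_worker_iterator,))
coordinator.join()
print(result.fetch())
```

Sharding by `get_current_worker_index` rather than by the `InputContext` argument is exactly what the new test below exercises, splitting `DatasetV2.range(10)` across five workers.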
View File

@ -35,6 +35,7 @@ from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.distribute.coordinator import coordinator_context
from tensorflow.python.distribute.coordinator import values as values_lib
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import def_function
@@ -495,8 +496,64 @@ def make_coordinator(num_workers, num_ps):
  return coordinator_lib.ClusterCoordinator(strategy)


class ClusterCoordinatorTest(TestCaseWithErrorReportingThread,
                             parameterized.TestCase):


class CoordinatorContextTest(test.TestCase, parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    super(CoordinatorContextTest, cls).setUpClass()
    cls.coordinator = make_coordinator(num_workers=5, num_ps=2)
    cls.strategy = cls.coordinator.strategy

  def testWorkerIndexDatasetFn(self):

    def dataset_fn(context):
      del context
      dataset = dataset_ops.DatasetV2.range(10)
      worker_index = coordinator_context.get_current_worker_index()
      dataset = dataset.shard(
          num_shards=self.strategy._extended._num_workers,
          index=worker_index,
      )
      return dataset

    @def_function.function
    def per_worker_dataset_fn():
      return self.strategy.distribute_datasets_from_function(dataset_fn)

    @def_function.function
    def train_fn(iterator):
      total = constant_op.constant(0, dtype=dtypes.int64)
      for batch in iterator:
        total += math_ops.reduce_sum(batch)
      return total

    per_worker_dataset = self.coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    with self.strategy.scope():
      iterator = iter(per_worker_dataset)

    ret_vals = []
    # Use private APIs to schedule in tagged queues to ensure each worker
    # executes only one closure.
    for ix in range(5):
      closure = coordinator_lib.Closure(
          train_fn,
          self.coordinator._cluster.closure_queue._cancellation_mgr,
          args=(iterator,))
      ret = closure.build_output_remote_value()
      # The queue doesn't keep track of tagged closures as inflight by
      # default, so hack around this for the test.
      self.coordinator._cluster.closure_queue._inflight_closure_count += 1
      self.coordinator._cluster.closure_queue.put(closure, tag=ix)
      ret_vals.append(ret)
    self.coordinator.join()

    fetched_vals = [rv.fetch() for rv in ret_vals]
    expected_results = [5, 7, 9, 11, 13]
    self.assertAllClose(sorted(fetched_vals), expected_results)


class ClusterCoordinatorTest(
    TestCaseWithErrorReportingThread, parameterized.TestCase
):

  @classmethod
  def setUpClass(cls):

View File

@@ -17,7 +17,13 @@
import contextlib
import threading

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor
from tensorflow.python.util import compat
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# There is a circular dependency between this and the `cluster_coordinator`
# module. So we load it lazily to work around this.
@@ -61,3 +67,70 @@ class DispatchContext(object):

  def maybe_get_remote_value(self, ret):
    return cluster_coordinator._maybe_get_remote_value(ret)  # pylint: disable=protected-access


@tf_export("distribute.experimental.coordinator.get_current_worker_index",
           v1=[])
def get_current_worker_index():
  """Returns the current worker index, when called within a worker closure.

  Some parameter server training workloads may require the worker to know its
  index, for example for data sharding for reduced-variance training.

  This method may be used within a `tf.function` that is executed on a worker.
  That is, either a `dataset_fn` that runs via
  `ClusterCoordinator.create_per_worker_dataset`, or any other function
  scheduled via `ClusterCoordinator.schedule`.

  Example (sharding data by worker):

  ```python
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = (
      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))

  def dataset_fn(context):
    dataset = tf.data.Dataset.range(10)
    worker_index = (
        tf.distribute.experimental.coordinator.get_current_worker_index()
    )
    dataset = dataset.shard(
        num_shards=num_workers,
        index=worker_index,
    )
    return dataset

  @tf.function
  def per_worker_dataset_fn():
    return strategy.distribute_datasets_from_function(dataset_fn)

  per_worker_dataset = coordinator.create_per_worker_dataset(
      per_worker_dataset_fn)
  ```

  Raises:
    RuntimeError: if called from outside a `tf.function` or outside of a remote
      closure execution context (that is, on a non-worker machine).
  """
  msg = ("Cannot retrieve the worker index. `get_current_worker_index` should "
         "be called from within a tf.function being executed on a worker. "
         "This method should only be called from either a dataset_fn that is "
         "passed into `ClusterCoordinator.create_per_worker_dataset`, or a "
         "tf.function that is passed into `ClusterCoordinator.schedule`.")
  if not ops.inside_function():
    raise RuntimeError(msg)

  def call_time_worker_index():
    dispatch_context = get_current_dispatch_context()
    if not dispatch_context:
      raise RuntimeError(msg)
    return dispatch_context.worker_index

  worker_index = ops.get_default_graph().capture_call_time_value(
      call_time_worker_index, tensor.TensorSpec([], dtype=dtypes.int64))
  worker_index.op._set_attr(  # pylint: disable=protected-access
      "_user_specified_name",
      attr_value_pb2.AttrValue(s=compat.as_bytes("worker_index")))
  return worker_index
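The docstring's example covers the `dataset_fn` path; for the `ClusterCoordinator.schedule` path it also mentions, a minimal sketch follows. This is illustrative only, not part of the commit: it again assumes a `TF_CONFIG`-configured cluster, and `worker_fn` is a hypothetical name.

```python
import tensorflow as tf

# Assumes TF_CONFIG describes a running parameter server cluster.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=tf.distribute.cluster_resolver.TFConfigClusterResolver())
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)


@tf.function
def worker_fn():
  # Captured as a call-time value, so it resolves to the index of whichever
  # worker the coordinator dispatches this closure to.
  return tf.distribute.experimental.coordinator.get_current_worker_index()


remote_value = coordinator.schedule(worker_fn)
coordinator.join()
print(remote_value.fetch())  # e.g. 0, 1, ... depending on the executing worker
```

Because the index is resolved per dispatch, scheduling the same function several times may return different values as the coordinator assigns each closure to any available worker.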

View File

@@ -12,4 +12,8 @@ tf_module {
    name: "RemoteValue"
    mtype: "<type \'type\'>"
  }
  member_method {
    name: "get_current_worker_index"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
}