Create experimental API for retrieving worker index during parameter server training.
PiperOrigin-RevId: 513253825
Parent: 3177298992
Commit: a5a1ced2e2
@@ -62,6 +62,13 @@
 *   <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
 *   <NOTES SHOULD BE GROUPED PER AREA>
 
+*   `tf.distribute`
+
+    *   Opened an experimental API,
+        `tf.distribute.experimental.coordinator.get_current_worker_index`, for
+        retrieving the worker index from within a worker, when using parameter
+        server training with a custom training loop.
+
 ## Thanks to our Contributors
 
 This release contains contributions from many people at Google, as well as:
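The release note above mentions only the custom-training-loop setting; per the docstring added later in this commit, the API can be called from a `dataset_fn` or from any function scheduled via `ClusterCoordinator.schedule`. Below is a minimal sketch of the scheduled-function path. It assumes a parameter server cluster is already running and reachable through a user-supplied `cluster_resolver`; neither the cluster nor that resolver is part of this commit.

```python
# Hedged sketch: requires a live parameter server cluster. `cluster_resolver`
# is an assumed, user-supplied tf.distribute.cluster_resolver.ClusterResolver.
import tensorflow as tf

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)

@tf.function
def which_worker():
  # Resolved at dispatch time on whichever worker runs this closure.
  return tf.distribute.experimental.coordinator.get_current_worker_index()

remote_value = coordinator.schedule(which_worker)
coordinator.join()
print("ran on worker", remote_value.fetch())
```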
@@ -35,6 +35,7 @@ from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
+from tensorflow.python.distribute.coordinator import coordinator_context
 from tensorflow.python.distribute.coordinator import values as values_lib
 from tensorflow.python.eager import cancellation
 from tensorflow.python.eager import def_function
@@ -495,8 +496,64 @@ def make_coordinator(num_workers, num_ps):
   return coordinator_lib.ClusterCoordinator(strategy)
 
 
-class ClusterCoordinatorTest(TestCaseWithErrorReportingThread,
-                             parameterized.TestCase):
+class CoordinatorContextTest(test.TestCase, parameterized.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    super(CoordinatorContextTest, cls).setUpClass()
+    cls.coordinator = make_coordinator(num_workers=5, num_ps=2)
+    cls.strategy = cls.coordinator.strategy
+
+  def testWorkerIndexDatasetFn(self):
+    def dataset_fn(context):
+      del context
+      dataset = dataset_ops.DatasetV2.range(10)
+      worker_index = coordinator_context.get_current_worker_index()
+      dataset = dataset.shard(
+          num_shards=self.strategy._extended._num_workers,
+          index=worker_index,
+      )
+      return dataset
+
+    @def_function.function
+    def per_worker_dataset_fn():
+      return self.strategy.distribute_datasets_from_function(dataset_fn)
+
+    @def_function.function
+    def train_fn(iterator):
+      total = constant_op.constant(0, dtype=dtypes.int64)
+      for batch in iterator:
+        total += math_ops.reduce_sum(batch)
+      return total
+
+    per_worker_dataset = self.coordinator.create_per_worker_dataset(
+        per_worker_dataset_fn)
+    with self.strategy.scope():
+      iterator = iter(per_worker_dataset)
+    ret_vals = []
+    # Use private APIs to schedule in tagged queues to ensure each worker
+    # executes only one closure.
+    for ix in range(5):
+      closure = coordinator_lib.Closure(
+          train_fn,
+          self.coordinator._cluster.closure_queue._cancellation_mgr,
+          args=(iterator,))
+      ret = closure.build_output_remote_value()
+      # The queue doesn't keep track of tagged closures as inflight by
+      # default, so hack around this for the test.
+      self.coordinator._cluster.closure_queue._inflight_closure_count += 1
+      self.coordinator._cluster.closure_queue.put(closure, tag=ix)
+      ret_vals.append(ret)
+    self.coordinator.join()
+
+    fetched_vals = [rv.fetch() for rv in ret_vals]
+    expected_results = [5, 7, 9, 11, 13]
+    self.assertAllClose(sorted(fetched_vals), expected_results)
+
+
+class ClusterCoordinatorTest(
+    TestCaseWithErrorReportingThread, parameterized.TestCase
+):
 
   @classmethod
   def setUpClass(cls):
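For the record, the expected values in the test above follow from how `Dataset.shard` splits `range(10)` across 5 workers: shard `ix` keeps the elements at positions congruent to `ix` modulo 5, i.e. `{ix, ix + 5}`, whose sum is `2 * ix + 5`, giving `[5, 7, 9, 11, 13]` for `ix = 0..4`. A standalone check of that arithmetic (plain Python, no cluster needed):

```python
# Mirrors Dataset.shard(num_shards, index): keep elements whose position
# modulo num_shards equals the shard index, then sum each shard.
num_shards = 5
data = range(10)
sums = [sum(x for x in data if x % num_shards == ix) for ix in range(num_shards)]
assert sums == [5, 7, 9, 11, 13]
print(sums)
```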
@@ -17,7 +17,13 @@
 import contextlib
 import threading
 
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor
+from tensorflow.python.util import compat
 from tensorflow.python.util.lazy_loader import LazyLoader
+from tensorflow.python.util.tf_export import tf_export
 
 # There is a circular dependency between this and the `cluster_coordinator`
 # module. So we load it lazily to work around this.
@@ -61,3 +67,70 @@ class DispatchContext(object):
 
   def maybe_get_remote_value(self, ret):
     return cluster_coordinator._maybe_get_remote_value(ret)  # pylint: disable=protected-access
+
+
+@tf_export("distribute.experimental.coordinator.get_current_worker_index",
+           v1=[])
+def get_current_worker_index():
+  """Returns the current worker index, when called within a worker closure.
+
+  Some parameter server training workloads may require the worker to know its
+  index, for example for data sharding for reduced-variance training.
+
+  This method may be used within a `tf.function` that is executed on a worker.
+  That is, either a `dataset_fn` that runs via
+  `ClusterCoordinator.create_per_worker_dataset`, or any other function
+  scheduled via `ClusterCoordinator.schedule`.
+
+  Example (sharding data by worker):
+
+  ```python
+  strategy = tf.distribute.experimental.ParameterServerStrategy(
+      cluster_resolver=...)
+  coordinator = (
+      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))
+
+  def dataset_fn(context):
+    dataset = tf.data.Dataset.range(10)
+    worker_index = (
+        tf.distribute.experimental.coordinator.get_current_worker_index()
+    )
+    dataset = dataset.shard(
+        num_shards=num_workers,
+        index=worker_index,
+    )
+    return dataset
+
+  @tf.function
+  def per_worker_dataset_fn():
+    return strategy.distribute_datasets_from_function(dataset_fn)
+
+  per_worker_dataset = coordinator.create_per_worker_dataset(
+      per_worker_dataset_fn)
+  ```
+
+  Raises:
+    RuntimeError: if called from outside a `tf.function` or outside of a remote
+      closure execution context (that is, on a non-worker machine).
+  """
+
+  msg = ("Cannot retrieve the worker index. `get_current_worker_index` "
+         "should be called from within a tf.function being executed on a "
+         "worker. This method should only be called from either a dataset_fn "
+         "that is passed into `ClusterCoordinator.create_per_worker_dataset`, "
+         "or a tf.function that is passed into `ClusterCoordinator.schedule`.")
+  if not ops.inside_function():
+    raise RuntimeError(msg)
+
+  def call_time_worker_index():
+    dispatch_context = get_current_dispatch_context()
+    if not dispatch_context:
+      raise RuntimeError(msg)
+    return dispatch_context.worker_index
+
+  worker_index = ops.get_default_graph().capture_call_time_value(
+      call_time_worker_index, tensor.TensorSpec([], dtype=dtypes.int64))
+  worker_index.op._set_attr(  # pylint: disable=protected-access
+      "_user_specified_name",
+      attr_value_pb2.AttrValue(s=compat.as_bytes("worker_index")))
+  return worker_index
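As the `Raises` section and the `ops.inside_function()` guard above indicate, calling the API eagerly on the coordinator, outside a `tf.function` dispatched to a worker, raises `RuntimeError`. A minimal sketch of that failure mode, assuming only a TensorFlow build that includes this change:

```python
# Hedged sketch: demonstrates the documented RuntimeError when the API is
# called eagerly, outside a tf.function running on a worker.
import tensorflow as tf

try:
  tf.distribute.experimental.coordinator.get_current_worker_index()
except RuntimeError as err:
  print("As expected:", err)
```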
@@ -12,4 +12,8 @@ tf_module {
     name: "RemoteValue"
     mtype: "<type \'type\'>"
   }
+  member_method {
+    name: "get_current_worker_index"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
 }