Implement compiled MWMS with XLA on GPU.

Required a new runner with a _noshare combination, and a test in Keras for full coverage.

PiperOrigin-RevId: 452006969
This commit is contained in:
parent a96533cc05
commit edb19a71c7
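For context, a minimal sketch of what "compiled MWMS" means at the user level: `MultiWorkerMirroredStrategy.run` called from inside a `tf.function(jit_compile=True)`. This is illustrative only and not part of the diff below; the model, optimizer, and cluster setup (e.g. `TF_CONFIG`) are placeholder assumptions.

```python
import tensorflow as tf

# Assumes TF_CONFIG (or another cluster resolver) is already set up for the
# multi-worker job; each worker runs this same program.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
  # Placeholder toy model and optimizer.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD(0.1)


@tf.function(jit_compile=True)  # The whole step, including collectives, is XLA-compiled.
def train_step(x, y):

  def step_fn(x, y):
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # With this change, the gradient all-reduce no longer requires merge_call
    # when all devices are GPUs, so the step can be compiled with XLA.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  return strategy.run(step_fn, args=(x, y))
```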
@@ -72,6 +72,9 @@
     not warnings will be printed when operations in the provided `fn` fall
     back to a while loop.
 
+* XLA:
+  * MWMS is now compilable with XLA.
+
 ## Bug Fixes and Other Changes
 
 * `tf.keras`:
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 
+#include <string>
+
 #include "absl/synchronization/notification.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/tf2xla/lib/util.h"
@@ -240,6 +242,10 @@ Status ResolveDeviceAssignment(
     }
     gpu_options.set_gpu_global_device_ids(global_device_ids);
   }
+  const std::string& communicator_key =
+      params->group.runtime_details.communicator_key;
+  gpu_options.set_nccl_unique_id_callback(
+      [=](const xla::gpu::NcclCliqueKey& key) { return communicator_key; });
   run_options.set_device_assignment(&device_assignment);
   run_options.set_gpu_executable_run_options(&gpu_options);
   return Status::OK();
@@ -40,10 +40,12 @@ from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
 from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
 from tensorflow.python.eager import context
+from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.training.tracking import base
@@ -297,6 +299,10 @@ class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
           else 0)
 
 
+def _is_gpu_device(device):
+  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"
+
+
 class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
   """Implementation of CollectiveAllReduceStrategy."""
 
@@ -333,9 +339,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
                       cross_device_ops_lib.CollectiveAllReduce)
 
   def _use_merge_call(self):
-    logging.log_first_n(logging.WARN, "XLA is not supported for multi-worker "
-                        "strategy.", 1)
-    return True
+    # We currently only disable merge_call when XLA is used to compile the `fn`
+    # passed to `strategy.run` and all devices are GPU.
+    return not control_flow_util.GraphOrParentsInXlaContext(
+        ops.get_default_graph()) or not all(
+            [_is_gpu_device(d) for d in self._devices])
 
   def _initialize_strategy(self, cluster_resolver):
     if cluster_resolver.cluster_spec().as_dict():
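The new `_use_merge_call` keys off whether tracing happens inside an XLA context. A small illustrative check (not from the diff; it assumes that `jit_compile=True` enters an XLA control-flow context during tracing, which is what the condition above relies on):

```python
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_util


def traced_in_xla_context():
  # Evaluated at trace time; the Python bool is captured as a constant.
  return control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())


@tf.function(jit_compile=True)
def compiled():
  # Expected to be True: the body is traced inside an XLA context, which is
  # what lets _use_merge_call return False for an all-GPU strategy.
  return traced_in_xla_context()


@tf.function
def not_compiled():
  # Expected to be False for a regular (non-jit) tf.function.
  return traced_in_xla_context()
```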
@@ -269,7 +269,10 @@ def _get_ps_strategy_creator(num_workers,
   return _create_parameter_server
 
 
-def _deferred_pool_runner(has_chief, num_workers, initializer=None):
+def _deferred_pool_runner(has_chief,
+                          num_workers,
+                          initializer=None,
+                          share_gpu=True):
   """Returns a callable that returns the pool runner.
 
   It creates the pool runner only upon first invocation. This avoids creating it
@@ -279,6 +282,7 @@ def _deferred_pool_runner(has_chief, num_workers, initializer=None):
     has_chief: whether there should be a chief.
     num_workers: the number of workers excluding the chief.
     initializer: initializer of each process.
+    share_gpu: whether to share GPU between the workers.
 
   Returns:
     A callable that returns the runner.
@@ -294,7 +298,7 @@ def _deferred_pool_runner(has_chief, num_workers, initializer=None):
         num_ps=0,
         has_eval=False)
     runner = multi_process_runner.MultiProcessPoolRunner(
-        cluster_spec, initializer=initializer)
+        cluster_spec, initializer=initializer, share_gpu=share_gpu)
     container.append(runner)
     return container[0]
 
@@ -307,6 +311,14 @@ _two_worker_pool = _deferred_pool_runner(
     has_chief=True,
     num_workers=1,
     initializer=_get_multi_worker_mirrored_creator(required_gpus=0))
+
+# Two-worker pool where each worker gets it's own GPU. Useful for testing MWMS
+# on a single host.
+_two_worker_pool_noshare = _deferred_pool_runner(
+    has_chief=True,
+    num_workers=1,
+    initializer=_get_multi_worker_mirrored_creator(required_gpus=0),
+    share_gpu=False)
 _four_worker_pool = _deferred_pool_runner(
     has_chief=True,
     num_workers=3,
@@ -413,7 +425,18 @@ multi_worker_mirrored_2x1_gpu = combinations.NamedDistribution(
     num_workers=1,
     required_gpus=1,
     pool_runner_fn=_two_worker_pool,
-    no_xla=True,
+    share_gpu=False,
 )
+
+# Same as above, but not sharing the GPU between the workers.
+multi_worker_mirrored_2x1_gpu_noshare = combinations.NamedDistribution(
+    "MultiWorkerMirrored2x1GPUNoShare",
+    _get_multi_worker_mirrored_creator(required_gpus=1),
+    has_chief=True,
+    num_workers=1,
+    required_gpus=1,
+    pool_runner_fn=_two_worker_pool_noshare,
+    share_gpu=False,
+)
 # chief + 1 worker, with 2 GPU each.
 multi_worker_mirrored_2x2_gpu = combinations.NamedDistribution(
@@ -602,6 +625,9 @@ tf_export(
 tf_export(
     _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu",
     v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu")
+tf_export(
+    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu_noshare",
+    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu_noshare")
 tf_export(
     _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x2_gpu",
     v1=[]).export_constant(__name__, "multi_worker_mirrored_2x2_gpu")
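For reference, a sketch of how a test (for example in Keras) might opt into the newly exported combination. The test class and body are placeholders, not the actual coverage test added by this commit:

```python
import tensorflow as tf
from absl.testing import parameterized

ds_combinations = tf.__internal__.distribute.combinations
test_combinations = tf.__internal__.test.combinations


class CompiledMwmsTest(tf.test.TestCase, parameterized.TestCase):

  @ds_combinations.generate(
      test_combinations.combine(
          strategy=[ds_combinations.multi_worker_mirrored_2x1_gpu_noshare],
          mode="eager"))
  def test_jit_compiled_run(self, strategy):

    @tf.function(jit_compile=True)
    def fn():
      return strategy.run(lambda: tf.constant(1.0))

    self.assertIsNotNone(fn())


if __name__ == "__main__":
  # Multi-worker combinations need the multi-process test main.
  tf.__internal__.distribute.multi_process_runner.test_main()
```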
@@ -48,6 +48,10 @@ tf_module {
     name: "multi_worker_mirrored_2x1_gpu"
     mtype: "<class \'tensorflow.python.distribute.combinations.NamedDistribution\'>"
   }
+  member {
+    name: "multi_worker_mirrored_2x1_gpu_noshare"
+    mtype: "<class \'tensorflow.python.distribute.combinations.NamedDistribution\'>"
+  }
   member {
     name: "multi_worker_mirrored_2x2_gpu"
     mtype: "<class \'tensorflow.python.distribute.combinations.NamedDistribution\'>"