Unify dtensor accelerator initialization to a single API.

dtensor.initialize_accelerator_system():
 - Reinitialization is only allowed after shutdown_accelerator_system() has been called.
 - Replaces initialize_tpu_system() and initialize_multi_client().

dtensor.shutdown_accelerator_system():
 - Only supported in single-controller mode, due to a TF distributed runtime limitation.
 - Replaces shutdown_tpu_system().

Clarify the definitions of the two modes:
 - single-controller: job_name == localhost, num_clients == 1.
 - multi-client: job_name != localhost, arbitrary num_clients.
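
For illustration, a minimal single-client sketch of the unified flow (assumes a TF build that ships `tf.experimental.dtensor`; the device choice and elided body are placeholders):

from tensorflow.experimental import dtensor

# Pick the accelerator explicitly, or pass None to fall back to
# dtensor.preferred_device_type().
device_type = dtensor.initialize_accelerator_system("CPU")

# ... build meshes and run DTensor computations ...

# Reinitialization is only allowed after an explicit shutdown, and shutdown
# itself is only supported in single-controller mode.
dtensor.shutdown_accelerator_system()
dtensor.initialize_accelerator_system(device_type)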

PiperOrigin-RevId: 473017056
A. Unique TensorFlower 2022-09-08 10:01:08 -07:00 committed by TensorFlower Gardener
parent 81dc1867ea
commit 25449cfe85
10 changed files with 321 additions and 253 deletions


@@ -53,6 +53,13 @@ This release contains contributions from many people at Google, as well as:
 *   RNG behavior change for `tf.keras.initializers`. Keras initializers will now use stateless random ops to generate random numbers.
     *   Both seeded and unseeded initializers will always generate the same values every time they are called (for a given variable shape). For unseeded initializers (`seed=None`), a random seed will be created and assigned at initializer creation (different initializer instances get different seeds).
     *   An unseeded initializer will raise a warning if it is reused (called) multiple times. This is because it would produce the same values each time, which may not be intended.
+*   API changes under `tf.experimental.dtensor`:
+    *   New API for initialization of CPU/GPU/TPU in dtensor.
+        `dtensor.initialize_accelerator_system` and
+        `dtensor.shutdown_accelerator_system`.
+    *   The following existing API will be removed:
+        `dtensor.initialize_multi_client`, `dtensor.initialize_tpu_system`, and
+        `dtensor.shutdown_tpu_system`.
 ## Deprecations
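
A hedged migration sketch for callers of the entry points listed above (the old names appear only in comments):

from tensorflow.experimental import dtensor

# Previously: dtensor.initialize_multi_client() for CPU/GPU clusters, or
# dtensor.initialize_tpu_system() for TPU, then dtensor.shutdown_tpu_system().
# With the unified API a single pair of calls covers all device types:
dtensor.initialize_accelerator_system("TPU")  # or "CPU" / "GPU" / None
dtensor.shutdown_accelerator_system()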


@@ -192,25 +192,32 @@ pytype_strict_library(
         "//tensorflow/dtensor:dtensor-users",
     ],
     deps = [
+        ":accelerator_util",
         ":api",
         ":config",
         ":layout",
-        ":multi_client_util",
         ":tpu_util",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:config",
         "//tensorflow/python:device",
         "//tensorflow/python:math_ops",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/framework:tfrt_utils",
         "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
+        "@absl_py//absl/flags",
         "@absl_py//absl/logging",
     ],
 )
 
+# TODO(b/245589661): Split accelerator_util to its module after
+# The circular dependence is removed with dtensor_initialize_tpu_system.
 pytype_strict_library(
     name = "tpu_util",
-    srcs = ["tpu_util.py"],
+    srcs = [
+        "accelerator_util.py",
+        "tpu_util.py",
+    ],
     visibility = default_visibility + [
         "//tensorflow/dtensor:dtensor-users",
     ],
@@ -221,9 +228,8 @@ pytype_strict_library(
         ":gen_dtensor_ops",
         ":heartbeat",
         ":layout",
-        ":multi_client_util",
+        "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:device",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
@@ -231,12 +237,14 @@ pytype_strict_library(
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/framework:config",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:tfrt_utils",
         "//tensorflow/python/tpu:topology",
         "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
         "@absl_py//absl/flags",
+        "@absl_py//absl/logging",
     ],
 )
@@ -256,14 +264,10 @@ pytype_strict_library(
 )
 
 pytype_strict_library(
-    name = "multi_client_util",
-    srcs = ["multi_client_util.py"],
+    name = "accelerator_util",
+    srcs = [],
     deps = [
-        ":config",
-        "//tensorflow/core:protos_all_py",
-        "//tensorflow/python:platform",
-        "//tensorflow/python/eager:context",
-        "@absl_py//absl/logging",
+        ":tpu_util",
     ],
 )


@@ -0,0 +1,235 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility for working with accelerator systems."""
+
+from typing import List, Optional
+
+from absl import flags
+from absl import logging
+
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import tensorflow_server_pb2
+from tensorflow.dtensor.python import api
+from tensorflow.dtensor.python import config
+from tensorflow.dtensor.python import tpu_util
+from tensorflow.python.eager import context
+from tensorflow.python.framework import config as tf_config
+from tensorflow.python.framework import tfrt_utils
+from tensorflow.python.platform import remote_utils
+from tensorflow.python.util.tf_export import tf_export
+
+_INITIALIZED_ACCELERATOR_SYSTEM_TYPE = None
+
+
+def is_initialized() -> bool:
+  """Returns whether accelerator system has been initialized."""
+  return bool(_INITIALIZED_ACCELERATOR_SYSTEM_TYPE)
+
+
+def initialize_multi_client_cluster(job_name: str,
+                                    dtensor_jobs: List[str],
+                                    client_id: int,
+                                    collective_leader: str,
+                                    port: Optional[int] = None,
+                                    enable_coordination_service: bool = False):
+  """Initialize GRPC servers and collectives for multi-client DTensor setup.
+
+  This function can be used to initialize a multi-client cluster and enable
+  collective ops. GRPC servers are necessary in the multi-client mode, even
+  when the number of clientis is 1.
+
+  NOTE: this function must be called in an eager context.
+
+  Args:
+    job_name: The job name used by all clients in the DTensor cluster.
+    dtensor_jobs: A list of the DTensor client jobs participating in the
+      cluster. Must be strings of the form "hostname:port".
+    client_id: The ID of the DTensor client this function is being called in.
+    collective_leader: The job/task that will be used to run collectives.
+    port: The port this client's GRPC server will run on. If omitted, use the
+      port from dtensor_jobs for this client.
+    enable_coordination_service: If true, enable distributed coordination
+      service to make sure that workers know the devices on each other, a
+      prerequisite for data transfer through cross-worker rendezvous.
+
+  Raises:
+    RuntimeError: If running inside a tf.function.
+  """
+  assert context.executing_eagerly()
+
+  if not collective_leader.startswith("/job:"):
+    collective_leader = "/job:" + collective_leader
+
+  config_proto = context.get_config()
+  config_proto.experimental.collective_group_leader = collective_leader
+  # Construct server def from the host directly instead of relying on
+  # TF_CONFIG.
+  cluster_def = cluster_pb2.ClusterDef()
+  # Note that for bns addresses, we will currently rely on the sorted string
+  # of job name as the order of assigning task ids. This might be brittle once
+  # we have jobs across multiple cells.
+  cluster_def.job.add(name=job_name, tasks=dict(enumerate(dtensor_jobs)))
+  server_def = tensorflow_server_pb2.ServerDef(
+      cluster=cluster_def,
+      default_session_config=config_proto,
+      job_name=job_name,
+      task_index=client_id,
+      protocol=remote_utils.get_default_communication_protocol(),
+      port=port)
+  server_def.default_session_config.rpc_options.num_channels_per_target = 4
+  server_def.default_session_config.experimental.recv_buf_max_chunk = -1
+
+  context.context().configure_collective_ops(
+      collective_leader=collective_leader)
+  if enable_coordination_service:
+    context.context().configure_coordination_service(
+        service_type="standalone", service_leader=collective_leader)
+
+  logging.info("Enabling collectives with server_def: %s", server_def)
+
+  context.context().enable_collective_ops(server_def)
+
+  context.ensure_initialized()
+
+
+def _configure_tpu_runtime():
+  was_enabled = context.is_tfrt_enabled()
+  if ("tpu_use_tfrt" in flags.FLAGS and flags.FLAGS["tpu_use_tfrt"].value):
+    tfrt_utils.set_tfrt_enabled(True)
+  if not was_enabled:
+    context._reset_context()  # pylint:disable=protected-access
+
+
+@tf_export(
+    "experimental.dtensor.initialize_accelerator_system",
+    "experimental.dtensor.initialize_tpu_system",
+    "experimental.dtensor.initialize_multi_client",
+    v1=[])
+def initialize_accelerator_system(
+    device_type: Optional[str] = None,
+    enable_coordination_service: Optional[bool] = False) -> str:
+  """Initializes accelerators and communication fabrics for DTensor.
+
+  DTensor configures TensorFlow to run in the local mode or multi-client mode.
+  - In local mode, a mesh can only use devices attached to the current process.
+  - In multi-client mode, a mesh can span across devices from multiple clients.
+
+  If `DTENSOR_JOBS` is non-empty, DTensor configures TensorFlow to run in the
+  multi-client mode using the distributed runtime. In multi-client mode devices
+  on different clients can communicate with each other.
+
+  The following environment variables controls the behavior of this function.
+
+  - `DTENSOR_JOBS`: string, a comma separated list. Each item in the list is
+      of format `{hostname}:{port}`. If empty, DTensor runs in the local mode.
+      Examples of valid `DTENSOR_JOBS` values:
+      - 4 clients on localhost:
+        `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
+      - 2 clients on host1, 2 clients on host2
+        `host1:10000,host1:10001,host2:10000,host2:10003`
+      If the hostnames are BNS addresses, the items must be sorted in
+      alphabetical order.
+  - `DTENSOR_CLIENT_ID`: integer, between `0` to `num_clients - 1`, to identify
+      the client id of the current process. The default value is `0`.
+  - `DTENSOR_JOB_NAME`: string, a string for the name of the TensorFlow job.
+      The job name controls the job name section of the TensorFlow DeviceSpecs,
+      e.g., `job:worker` in `/job:worker/replica:0/task:0/device:TPU:0` when
+      the job name is `worker`.
+      The default value is `localhost` in local mode, and
+      `worker` when in the multi-client mode. All DTensor clients within the
+      same multi-client cluster share the same job name.
+
+  Args:
+    device_type: Type of accelerator to use, can be CPU, GPU, or TPU. If None,
+      uses `tf.experimental.dtensor.preferred_device_type()`.
+    enable_coordination_service: If true, enable distributed coordination
+      service to make sure that workers know the devices on each other, when
+      there is more than 1 client.
+
+  Returns:
+    device_type: the type of accelerator that was initialized.
+  """
+  global _INITIALIZED_ACCELERATOR_SYSTEM_TYPE
+  assert context.executing_eagerly()
+
+  if _INITIALIZED_ACCELERATOR_SYSTEM_TYPE:
+    raise ValueError(
+        "Accelerator system has already been initialized. "
+        "Call tf.experimental.dtensor.shutdown_acceerator_system() first.")
+
+  context.context()._clear_caches()  # pylint: disable=protected-access
+
+  if device_type is None:
+    device_type = config.preferred_device_type()
+
+  device_type = device_type.upper()
+  if device_type not in {"CPU", "GPU", "TPU"}:
+    raise ValueError(f"Unknown device_type {device_type}. "
+                     "Allowed values are CPU, GPU, or TPU")
+
+  # Reconfigure TensorFlow to use TFRT TPU runtime if requested.
+  if device_type == "TPU":
+    _configure_tpu_runtime()
+
+  # Configure logical host CPU devices for accelerators.
+  if device_type in ("GPU", "TPU"):
+    num_local_devices = api.num_local_devices(device_type)
+    if api.num_local_devices("CPU") < num_local_devices:
+      tf_config.set_logical_device_configuration(
+          tf_config.list_physical_devices("CPU")[0],
+          [context.LogicalDeviceConfiguration()] * num_local_devices)
+
+  if not config.is_local_mode():
+    initialize_multi_client_cluster(
+        job_name=config.job_name(),
+        dtensor_jobs=config.jobs(),
+        client_id=config.client_id(),
+        collective_leader=config.full_job_name(task_id=0),
+        enable_coordination_service=enable_coordination_service)
+
+  if device_type == "TPU":
+    tpu_util.initialize_tpu_system()
+
+  _INITIALIZED_ACCELERATOR_SYSTEM_TYPE = device_type
+  return device_type
+
+
+@tf_export(
+    "experimental.dtensor.shutdown_accelerator_system",
+    "experimental.dtensor.shutdown_tpu_system",
+    v1=[])
+def shutdown_accelerator_system() -> None:
+  """Shuts down the accelerator system."""
+  global _INITIALIZED_ACCELERATOR_SYSTEM_TYPE
+  context.async_wait()
+
+  if not is_initialized():
+    raise ValueError(
+        "Accelerator system is not initialized. Call "
+        "tf.experimental.dtensor.initialize_accelerator_system first.")
+
+  device_type = _INITIALIZED_ACCELERATOR_SYSTEM_TYPE
+
+  if not config.is_local_mode():
+    raise ValueError(
+        "Shutting down accelerator system under multi-client mode is "
+        "not supported.")
+
+  if device_type == "TPU":
+    tpu_util.shutdown_tpu_system()
+
+  # reset TF context to stop gRPC servers.
+  context._reset_context()  # pylint: disable=protected-access
+  context.context()._clear_caches()  # pylint: disable=protected-access
+  _INITIALIZED_ACCELERATOR_SYSTEM_TYPE = None
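
For illustration, a sketch of a multi-client launch driven by the environment variables documented above (hostnames, ports, and the choice of GPU are hypothetical; in practice these variables are usually set by the job launcher rather than in user code):

import os
from tensorflow.experimental import dtensor

# Two clients on two hosts; this process is client 0.
os.environ["DTENSOR_JOBS"] = "host1:10000,host2:10000"
os.environ["DTENSOR_CLIENT_ID"] = "0"
os.environ["DTENSOR_JOB_NAME"] = "worker"

# Because DTENSOR_JOBS is non-empty, this runs in multi-client mode: a gRPC
# server is started and collectives are enabled across both clients.
dtensor.initialize_accelerator_system(
    "GPU", enable_coordination_service=True)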


@@ -466,6 +466,19 @@ def num_global_devices(device_type: str) -> int:
 # Private methods.
 
 
+def is_tpu_present() -> bool:
+  """Returns true if TPU devices are present."""
+  # Check if TPU is present from initialized context.
+  # TPU_SYSTEM is a logical device that indicates TPUs are present.
+  tpu_system_devices = tf_config.list_physical_devices("TPU_SYSTEM")
+  return len(tpu_system_devices) > 0  # pylint: disable=g-explicit-length-test
+
+
+def is_gpu_present() -> bool:
+  """Returns true if TPU devices are present."""
+  return len(tf_config.list_physical_devices("GPU")) > 0  # pylint: disable=g-explicit-length-test
+
+
 def _set_dtensor_device(device: dtensor_device.DTensorDevice) -> None:
   global _dtensor_singleton
   _dtensor_singleton = device


@@ -21,6 +21,7 @@ from tensorflow.python.framework import config as tf_config
 from tensorflow.python.util.tf_export import tf_export
 
 _DT_CLIENT_ID = "DTENSOR_CLIENT_ID"
+# DTENSOR_NUM_CLIENTS is removed, but some DTensor users still use this symbol.
 _DT_NUM_CLIENTS = "DTENSOR_NUM_CLIENTS"
 _DT_JOB_NAME = "DTENSOR_JOB_NAME"
 _DT_JOBS = "DTENSOR_JOBS"
@@ -48,13 +49,9 @@ def client_id() -> int:
 @tf_export("experimental.dtensor.num_clients", v1=[])
 def num_clients() -> int:
   """Returns the number of clients in this DTensor cluster."""
-  # If missing, assume running with a single client with num_clients of 1.
-  num_clients_value = int(os.environ.get(_DT_NUM_CLIENTS, "1"))
-  if num_clients_value <= 0:
-    raise ValueError(f"Environment variable {_DT_NUM_CLIENTS} "
-                     f"must be > 0, got {num_clients_value}.")
-
-  return num_clients_value
+  if is_local_mode():
+    return 1
+  return len(jobs())
 
 
 @tf_export("experimental.dtensor.job_name", v1=[])
@@ -123,6 +120,11 @@ def heartbeat_enabled() -> bool:
   return os.environ.get(_DT_HEARTBEAT_ENABLED, "true").lower() in ("true", "1")
 
 
+def is_local_mode() -> bool:
+  """Returns true if DTensor shall run in local mode."""
+  return not jobs()
+
+
 def is_tpu_present() -> bool:
   """Returns true if TPU devices are present."""
   # Check if TPU is present from initialized context.
@@ -136,6 +138,7 @@ def is_gpu_present() -> bool:
   return bool(tf_config.list_physical_devices("GPU"))
 
 
+@tf_export("experimental.dtensor.preferred_device_type", v1=[])
 def preferred_device_type() -> str:
   """Returns the preferred device type for the accelerators.


@@ -18,10 +18,10 @@ from typing import List, Optional, Tuple
 from absl import logging
 import numpy as np
 
+from tensorflow.dtensor.python import accelerator_util
 from tensorflow.dtensor.python import api
 from tensorflow.dtensor.python import config
 from tensorflow.dtensor.python import layout
-from tensorflow.dtensor.python import multi_client_util
 from tensorflow.dtensor.python import tpu_util
 from tensorflow.python.eager import context
 from tensorflow.python.framework import config as tf_config
@@ -49,6 +49,7 @@ def _make_device_specs(
     device_type: Optional[str] = None
 ) -> Tuple[List[tf_device.DeviceSpec], str]:
   """Makes device specs from local devices names or number of global devices."""
+
   if devices is None:
     if device_type is None:
       device_type = 'CPU'
@@ -157,6 +158,10 @@ def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
   """
   dim_names, shape = zip(*mesh_dims)
 
+  if not accelerator_util.is_initialized():
+    raise ValueError('Accelerators are uninitialized, please run '
+                     'dtensor.initialize_accelerator_system() first.')
+
   if device_type and device_type.upper() == 'TPU':
     # TODO(b/185940495): Allow multi-mesh and partial on TPU.
     # TPU meshes can only be configured through environment variables that
@@ -174,10 +179,6 @@ def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
     # This is particularly useful on single clients when users want to create
     # meshes that use fewer logical devices than what's available.
 
-    if config.num_clients() > 1 and not multi_client_util.is_initialized():
-      raise ValueError('Invalid multi-client topology, please run '
-                       'dtensor.initialize_multi_client() first.')
-
     local_spec = tf_device.DeviceSpec(
         job=config.job_name(), replica=0, task=config.client_id())
     device_specs = [local_spec.make_merged_spec(d) for d in device_specs]
@@ -217,60 +218,6 @@ def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
   raise ValueError(f'Device type {device_type} is not CPU, GPU or TPU')
 
 
-@tf_export('experimental.dtensor.initialize_multi_client', v1=[])
-def dtensor_initialize_multi_client(
-    enable_coordination_service: Optional[bool] = False) -> None:
-  """Initializes Multi Client DTensor.
-
-  The following environment variables controls the behavior of this function.
-  If the variables are unset, DTensor will be configured to run in single-client
-  mode.
-
-  - DTENSOR_CLIENT_ID: integer, between 0 to num_clients - 1, to identify the
-      client id of the current process. The default value is 0.
-  - DTENSOR_NUM_CLIENTS: integer, the number of clients. The default value is 1.
-  - DTENSOR_JOB_NAME: string, a hostname like string for the name of the dtensor
-      job. The default is `localhost` when number of clients is 1, and `worker`
-      when the number of clients is greater than 1.
-      The job name controls the job name section of the TensorFlow DeviceSpecs,
-      e.g., `job:worker` in `/job:worker/replica:0/task:0/device:TPU:0` when
-      the job name is `worker`.
-  - DTENSOR_JOBS: string, a comma separated list. Each item in the list is
-      of format `{hostname}:{port}` and the items must be sorted in alphabet
-      order. The implication is the RPC port numbers of the clients from
-      the same host must be ordered by the client ID.
-      Examples of valid DTENSOR_JOBS values:
-      - 4 clients on localhost:
-        `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
-      - 2 clients on host1, 2 clients on host2
-        `host1:10000,host1:10001,host2:10000,host2:10003`
-
-  Args:
-    enable_coordination_service: If true, enable distributed coordination
-      service to make sure that workers know the devices on each other, a
-      prerequisite for data transfer through cross-worker rendezvous.
-  """
-  assert context.executing_eagerly()
-
-  # Collective GRPC servers are only necessary in multi-client setup.
-  # Single clients can use local mode of collectives.
-  if config.num_clients() > 1:
-    multi_client_util.initialize_multi_client_cluster(
-        job_name=config.job_name(),
-        dtensor_jobs=config.jobs(),
-        client_id=config.client_id(),
-        collective_leader=config.full_job_name(task_id=0),
-        enable_coordination_service=enable_coordination_service)
-
-  # Make sure the server change is fully propagated before returning.
-  context.ensure_initialized()
-  context.async_wait()
-  context.context()._clear_caches()  # pylint: disable=protected-access
-
-  # Unlike TPU, do not enable heartbeat service.
-  # They tend to interfere with regular GPU/CPU collective Ops.
-
-
 @tf_export('experimental.dtensor.barrier', v1=[])
 def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
   """Runs a barrier on the mesh.


@@ -1,115 +0,0 @@
-# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility wrappers for working with multi-client setups."""
-
-from typing import List, Optional
-
-from absl import logging
-
-from tensorflow.core.protobuf import cluster_pb2
-from tensorflow.core.protobuf import tensorflow_server_pb2
-from tensorflow.dtensor.python import config
-from tensorflow.python.eager import context
-from tensorflow.python.platform import remote_utils
-
-_is_multi_client_initialized = False
-
-
-def initialize_multi_client_cluster(job_name: str,
-                                    dtensor_jobs: List[str],
-                                    client_id: int,
-                                    collective_leader: str,
-                                    port: Optional[int] = None,
-                                    enable_coordination_service: bool = False):
-  """Initialize GRPC servers and collectives for multi-client DTensor setup.
-
-  While single clients (e.g. Forge) can use local mode of collectives, GRPC
-  servers are necessary in mutli-client setup. This function can be used to
-  initialize a cluster and enable collective ops.
-
-  NOTE: this function must be called in an eager context.
-
-  Args:
-    job_name: The job name used by all clients in the DTensor cluster.
-    dtensor_jobs: A list of the DTensor client jobs participating in the
-      cluster. Must be strings of the form "hostname:port".
-    client_id: The ID of the DTensor client this function is being called in.
-    collective_leader: The job/task that will be used to run collectives.
-    port: The port this client's GRPC server will run on.
-    enable_coordination_service: If true, enable distributed coordination
-      service to make sure that workers know the devices on each other, a
-      prerequisite for data transfer through cross-worker rendezvous.
-
-  Raises:
-    RuntimeError: If running inside a tf.function.
-  """
-  global _is_multi_client_initialized
-  assert context.executing_eagerly()
-
-  if _is_multi_client_initialized:
-    raise ValueError("Multi-client mode has already been initialized.")
-
-  if config.num_clients() <= 1:
-    raise ValueError(
-        "DTENSOR_NUM_CLIENTS must be set greater than 1 for multi-client mode.")
-
-  if not config.jobs() or len(config.jobs()) <= 1:
-    raise ValueError(
-        "DTENSOR_JOBS environment variable is required when using multi-client "
-        "mode to properly set up communications between servers.")
-
-  if len(config.jobs()) != config.num_clients():
-    raise ValueError(
-        "DTENSOR_JOBS environment variable must be configured with the same "
-        "number of items as DTENSOR_NUM_CLIENTS.")
-
-  if not collective_leader.startswith("/job:"):
-    collective_leader = "/job:" + collective_leader
-
-  context.context().configure_collective_ops(
-      collective_leader=collective_leader)
-  if enable_coordination_service:
-    context.context().configure_coordination_service(
-        service_type="standalone", service_leader=collective_leader)
-
-  config_proto = context.get_config()
-  config_proto.experimental.collective_group_leader = collective_leader
-  # Construct server def from the host directly instead of relying on
-  # TF_CONFIG.
-  cluster_def = cluster_pb2.ClusterDef()
-  # Note that we will currently rely on the sorted string of job name as the
-  # order of assigning task ids. This might be brittle once we have jobs
-  # across multiple cells.
-  cluster_def.job.add(name=job_name, tasks=dict(enumerate(dtensor_jobs)))
-  server_def = tensorflow_server_pb2.ServerDef(
-      cluster=cluster_def,
-      default_session_config=config_proto,
-      job_name=job_name,
-      task_index=client_id,
-      protocol=remote_utils.get_default_communication_protocol(),
-      port=port)
-  server_def.default_session_config.rpc_options.num_channels_per_target = 4
-  server_def.default_session_config.experimental.recv_buf_max_chunk = -1
-
-  logging.info("Enabling collectives with server_def: %s", server_def)
-
-  context.context().enable_collective_ops(server_def)
-  context.ensure_initialized()
-
-  _is_multi_client_initialized = True
-
-
-def is_initialized() -> bool:
-  """Returns whether multi-client mode has been initialized."""
-  return _is_multi_client_initialized


@@ -18,7 +18,6 @@ import functools
 import time
 from typing import List, Optional, Dict
 
-from absl import flags
 import numpy as np
 
 from tensorflow.dtensor.python import api
@@ -27,22 +26,18 @@ from tensorflow.dtensor.python import dtensor_device
 from tensorflow.dtensor.python import gen_dtensor_ops
 from tensorflow.dtensor.python import heartbeat
 from tensorflow.dtensor.python import layout as layout_lib
-from tensorflow.dtensor.python import multi_client_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tfrt_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import topology
-from tensorflow.python.util.tf_export import tf_export
 
-_INITIALIZED_TPU_SYSTEMS = {}
-
 _MESH_DIM_X = "x"
 _TPU_DEVICE_TYPE = "TPU"
@@ -132,9 +127,8 @@ def _create_tpu_topology(core_locations: List[_CoreLocation], num_tasks: int,
       mesh_shape=mesh_shape, device_coordinates=device_coordinates)
 
 
-@tf_export("experimental.dtensor.shutdown_tpu_system", v1=[])
-def dtensor_shutdown_tpu_system():
-  """Shutdown TPU system."""
+def shutdown_tpu_system():
+  """Shuts down the TPU system."""
 
   @def_function.function
   def _shutdown_tpu_system():
@@ -147,39 +141,8 @@ def dtensor_shutdown_tpu_system():
     logging.warning("TPU system fails to shut down.")
 
 
-@tf_export("experimental.dtensor.initialize_tpu_system", v1=[])
-def dtensor_initialize_tpu_system(enable_coordination_service=False):
-  """Initialize the TPU devices.
-
-  This functions performs additional TPU related initialization after
-  calling `dtensor.initialize_multi_client` to initialize multi-client DTensor.
-
-  Refer to `dtensor.initialize_multi_client` for relevant environment
-  variables that controls the initialization of multi-client DTensor.
-
-  Args:
-    enable_coordination_service: If true, enable distributed coordination
-      service to make sure that workers know the devices on each other, a
-      prerequisite for data transfer through cross-worker rendezvous.
-
-  Raises:
-    RuntimeError: If running inside a tf.function.
-    NotFoundError: If no TPU devices found in eager mode.
-  """
-  assert context.executing_eagerly()
-
-  # Reconfigure TensorFlow to use TFRT TPU runtime if requested.
-  _configure_tpu_runtime()
-
-  # Collective GRPC servers are only necessary in mutli-client setup.
-  # Single clients can use local mode of collectives.
-  if config.num_clients() > 1 and not multi_client_util.is_initialized():
-    multi_client_util.initialize_multi_client_cluster(
-        job_name=config.job_name(),
-        dtensor_jobs=config.jobs(),
-        client_id=config.client_id(),
-        collective_leader=config.full_job_name(task_id=0),
-        enable_coordination_service=enable_coordination_service)
+def initialize_tpu_system():
+  """Initializes the TPU system."""
 
   # Make sure the server change is fully propagated before attempting to run
   # the core ID merging logic below.
@@ -199,14 +162,6 @@ def dtensor_initialize_tpu_system(enable_coordination_service=False):
   with ops.device("/job:" + config.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
     my_core_ids = _tpu_init_fn()
   logging.info("TPU core IDs: %s", my_core_ids)
-  context.initialize_logical_devices()
-
-  # Configure virtual CPUs that is 1:1 mapped to TPU cores.
-  context.context().set_logical_cpu_devices(
-      len(api.local_devices(_TPU_DEVICE_TYPE)),
-      tf_device.DeviceSpec(
-          job=config.job_name(), replica=0,
-          task=config.client_id()).to_string())
 
   # `my_core_ids` contains the IDs of TPU cores attached to this host.
   #
@@ -727,8 +682,8 @@ def create_tpu_mesh(mesh_dim_names: List[str],
   # easier interaction with the C++ API.
   global_core_locations = [l.to_list() for l in global_core_locations]
   if _dtensor_device is None:
-    raise ValueError(
-        "Invalid system device, run dtensor.initialize_tpu_system() first")
+    raise ValueError("Invalid system device, "
+                     "run dtensor.initialize_accelerator_system() first")
   global_core_ids = _dtensor_device.tpu_core_locations_to_ids(
       global_core_locations)
@@ -809,9 +764,16 @@ def get_device_locations(
       "Looking up other clients' device locations is not supported")
 
 
-def _configure_tpu_runtime():
-  was_enabled = context.is_tfrt_enabled()
-  if ("tpu_use_tfrt" in flags.FLAGS and flags.FLAGS["tpu_use_tfrt"].value):
-    tfrt_utils.set_tfrt_enabled(True)
-  if not was_enabled:
-    context._reset_context()  # pylint:disable=protected-access
+# TODO(b/245589661): Remove dtensor_initialize_tpu_system() and
+# dtensor_shutdown_tpu_system() after users stopped using them.
+def dtensor_initialize_tpu_system(enable_coordination_service=False):
+  """Deprecated way to initialize the TPU system."""
+  from . import accelerator_util  # pylint: disable=g-import-not-at-top
+  accelerator_util.initialize_accelerator_system(
+      "TPU", enable_coordination_service=enable_coordination_service)
+
+
+def dtensor_shutdown_tpu_system():
+  """Deprecated way to shutodwn the TPU system."""
+  from . import accelerator_util  # pylint: disable=g-import-not-at-top
+  accelerator_util.shutdown_accelerator_system()
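
Illustrative sketch of what the temporary shims above mean for legacy call sites (this is an internal module path, not a public API):

# Hypothetical legacy call site inside TensorFlow; tpu_util here is the
# tensorflow/dtensor/python/tpu_util.py module touched by this change.
from tensorflow.dtensor.python import tpu_util

# Both shims simply forward to the new accelerator_util entry points.
tpu_util.dtensor_initialize_tpu_system()  # -> initialize_accelerator_system("TPU")
tpu_util.dtensor_shutdown_tpu_system()    # -> shutdown_accelerator_system()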


@@ -45,6 +45,7 @@ def gen_api_init_files(
         compat_init_templates = [],
         packages = [
             "tensorflow.python",
+            "tensorflow.dtensor.python.accelerator_util",
             "tensorflow.dtensor.python.api",
             "tensorflow.dtensor.python.config",
             "tensorflow.dtensor.python.d_checkpoint",
@@ -53,7 +54,6 @@ def gen_api_init_files(
             "tensorflow.dtensor.python.layout",
             "tensorflow.dtensor.python.mesh_util",
             "tensorflow.dtensor.python.save_restore",
-            "tensorflow.dtensor.python.tpu_util",
            "tensorflow.lite.python.analyzer",
            "tensorflow.lite.python.lite",
            "tensorflow.lite.python.authoring.authoring",

View File

@@ -76,13 +76,17 @@ tf_module {
     name: "heartbeat_enabled"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "initialize_accelerator_system"
+    argspec: "args=[\'device_type\', \'enable_coordination_service\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
   member_method {
     name: "initialize_multi_client"
-    argspec: "args=[\'enable_coordination_service\'], varargs=None, keywords=None, defaults=[\'False\'], "
+    argspec: "args=[\'device_type\', \'enable_coordination_service\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
   }
   member_method {
     name: "initialize_tpu_system"
-    argspec: "args=[\'enable_coordination_service\'], varargs=None, keywords=None, defaults=[\'False\'], "
+    argspec: "args=[\'device_type\', \'enable_coordination_service\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
   }
   member_method {
     name: "job_name"
@@ -120,6 +124,10 @@ tf_module {
     name: "pack"
     argspec: "args=[\'tensors\', \'layout\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "preferred_device_type"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "relayout"
     argspec: "args=[\'tensor\', \'layout\'], varargs=None, keywords=None, defaults=None"
@@ -132,6 +140,10 @@ tf_module {
     name: "sharded_save"
     argspec: "args=[\'mesh\', \'file_prefix\', \'tensor_names\', \'shape_and_slices\', \'tensors\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "shutdown_accelerator_system"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "shutdown_tpu_system"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"