mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-06 12:20:11 +01:00
Unify dtensor accelerator initialization to a single API.
dtensor.initialize_accelerator_system():
- Reinitialization is only allowed after shutdown_accelerator_system() is called.
- Replaces initialize_tpu_system() and initialize_multi_client().

dtensor.shutdown_accelerator_system():
- Only supported in single-controller mode due to a TF distributed runtime limitation.
- Replaces shutdown_tpu_system().

Clarify the definitions of
- single-controller: job_name == localhost, num_clients == 1, and
- multi-client: job_name != localhost, arbitrary num_clients.

PiperOrigin-RevId: 473017056

This commit is contained in:
parent 81dc1867ea
commit 25449cfe85
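As a usage sketch of the unified API described above (not part of the commit; the CPU device type and the idea that meshes are built between the two calls are illustrative assumptions):

```python
from tensorflow.experimental import dtensor

# One call replaces dtensor.initialize_tpu_system() and
# dtensor.initialize_multi_client(); it returns the device type it initialized.
device_type = dtensor.initialize_accelerator_system("CPU")

# ... create meshes and run DTensor computations here ...

# Single-controller mode only; after this, initialize_accelerator_system()
# may be called again.
dtensor.shutdown_accelerator_system()
```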
@@ -53,6 +53,13 @@ This release contains contributions from many people at Google, as well as:
 * RNG behavior change for `tf.keras.initializers`. Keras initializers will now use stateless random ops to generate random numbers.
   * Both seeded and unseeded initializers will always generate the same values every time they are called (for a given variable shape). For unseeded initializers (`seed=None`), a random seed will be created and assigned at initializer creation (different initializer instances get different seeds).
   * An unseeded initializer will raise a warning if it is reused (called) multiple times. This is because it would produce the same values each time, which may not be intended.
+* API changes under `tf.experimental.dtensor`:
+  * New API for initialization of CPU/GPU/TPU in dtensor:
+    `dtensor.initialize_accelerator_system` and
+    `dtensor.shutdown_accelerator_system`.
+  * The following existing APIs will be removed:
+    `dtensor.initialize_multi_client`, `dtensor.initialize_tpu_system`, and
+    `dtensor.shutdown_tpu_system`.

 ## Deprecations
@@ -192,25 +192,32 @@ pytype_strict_library(
         "//tensorflow/dtensor:dtensor-users",
     ],
     deps = [
+        ":accelerator_util",
         ":api",
         ":config",
         ":layout",
-        ":multi_client_util",
         ":tpu_util",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:config",
         "//tensorflow/python:device",
         "//tensorflow/python:math_ops",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/framework:tfrt_utils",
         "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
+        "@absl_py//absl/flags",
         "@absl_py//absl/logging",
     ],
 )

+# TODO(b/245589661): Split accelerator_util to its module after
+# the circular dependence is removed with dtensor_initialize_tpu_system.
 pytype_strict_library(
     name = "tpu_util",
-    srcs = ["tpu_util.py"],
+    srcs = [
+        "accelerator_util.py",
+        "tpu_util.py",
+    ],
     visibility = default_visibility + [
         "//tensorflow/dtensor:dtensor-users",
     ],
@@ -221,9 +228,8 @@ pytype_strict_library(
         ":gen_dtensor_ops",
         ":heartbeat",
         ":layout",
-        ":multi_client_util",
+        "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:device",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
@@ -231,12 +237,14 @@ pytype_strict_library(
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/framework:config",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:tfrt_utils",
         "//tensorflow/python/tpu:topology",
         "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
         "@absl_py//absl/flags",
+        "@absl_py//absl/logging",
     ],
 )

@@ -256,14 +264,10 @@ pytype_strict_library(
 )

 pytype_strict_library(
-    name = "multi_client_util",
-    srcs = ["multi_client_util.py"],
+    name = "accelerator_util",
+    srcs = [],
     deps = [
-        ":config",
-        "//tensorflow/core:protos_all_py",
-        "//tensorflow/python:platform",
-        "//tensorflow/python/eager:context",
-        "@absl_py//absl/logging",
+        ":tpu_util",
     ],
 )

tensorflow/dtensor/python/accelerator_util.py (new file, 235 lines)
@@ -0,0 +1,235 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility for working with accelerator systems."""

from typing import List, Optional

from absl import flags
from absl import logging

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import tpu_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import tfrt_utils
from tensorflow.python.platform import remote_utils
from tensorflow.python.util.tf_export import tf_export

_INITIALIZED_ACCELERATOR_SYSTEM_TYPE = None


def is_initialized() -> bool:
  """Returns whether accelerator system has been initialized."""
  return bool(_INITIALIZED_ACCELERATOR_SYSTEM_TYPE)


def initialize_multi_client_cluster(job_name: str,
                                    dtensor_jobs: List[str],
                                    client_id: int,
                                    collective_leader: str,
                                    port: Optional[int] = None,
                                    enable_coordination_service: bool = False):
  """Initialize GRPC servers and collectives for multi-client DTensor setup.

  This function can be used to initialize a multi-client cluster and enable
  collective ops. GRPC servers are necessary in the multi-client mode, even
  when the number of clients is 1.

  NOTE: this function must be called in an eager context.

  Args:
    job_name: The job name used by all clients in the DTensor cluster.
    dtensor_jobs: A list of the DTensor client jobs participating in the
      cluster. Must be strings of the form "hostname:port".
    client_id: The ID of the DTensor client this function is being called in.
    collective_leader: The job/task that will be used to run collectives.
    port: The port this client's GRPC server will run on. If omitted, use the
      port from dtensor_jobs for this client.
    enable_coordination_service: If true, enable distributed coordination
      service to make sure that workers know the devices on each other, a
      prerequisite for data transfer through cross-worker rendezvous.

  Raises:
    RuntimeError: If running inside a tf.function.
  """
  assert context.executing_eagerly()

  if not collective_leader.startswith("/job:"):
    collective_leader = "/job:" + collective_leader

  config_proto = context.get_config()
  config_proto.experimental.collective_group_leader = collective_leader
  # Construct server def from the host directly instead of relying on
  # TF_CONFIG.
  cluster_def = cluster_pb2.ClusterDef()
  # Note that for bns addresses, we will currently rely on the sorted string
  # of job name as the order of assigning task ids. This might be brittle once
  # we have jobs across multiple cells.
  cluster_def.job.add(name=job_name, tasks=dict(enumerate(dtensor_jobs)))
  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_def,
      default_session_config=config_proto,
      job_name=job_name,
      task_index=client_id,
      protocol=remote_utils.get_default_communication_protocol(),
      port=port)
  server_def.default_session_config.rpc_options.num_channels_per_target = 4
  server_def.default_session_config.experimental.recv_buf_max_chunk = -1

  context.context().configure_collective_ops(
      collective_leader=collective_leader)
  if enable_coordination_service:
    context.context().configure_coordination_service(
        service_type="standalone", service_leader=collective_leader)

  logging.info("Enabling collectives with server_def: %s", server_def)
  context.context().enable_collective_ops(server_def)
  context.ensure_initialized()


def _configure_tpu_runtime():
  was_enabled = context.is_tfrt_enabled()
  if ("tpu_use_tfrt" in flags.FLAGS and flags.FLAGS["tpu_use_tfrt"].value):
    tfrt_utils.set_tfrt_enabled(True)
  if not was_enabled:
    context._reset_context()  # pylint:disable=protected-access


@tf_export(
    "experimental.dtensor.initialize_accelerator_system",
    "experimental.dtensor.initialize_tpu_system",
    "experimental.dtensor.initialize_multi_client",
    v1=[])
def initialize_accelerator_system(
    device_type: Optional[str] = None,
    enable_coordination_service: Optional[bool] = False) -> str:
  """Initializes accelerators and communication fabrics for DTensor.

  DTensor configures TensorFlow to run in the local mode or multi-client mode.
  - In local mode, a mesh can only use devices attached to the current process.
  - In multi-client mode, a mesh can span across devices from multiple clients.

  If `DTENSOR_JOBS` is non-empty, DTensor configures TensorFlow to run in the
  multi-client mode using the distributed runtime. In multi-client mode devices
  on different clients can communicate with each other.

  The following environment variables control the behavior of this function.

  - `DTENSOR_JOBS`: string, a comma separated list. Each item in the list is
      of format `{hostname}:{port}`. If empty, DTensor runs in the local mode.
      Examples of valid `DTENSOR_JOBS` values:
      - 4 clients on localhost:
        `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
      - 2 clients on host1, 2 clients on host2
        `host1:10000,host1:10001,host2:10000,host2:10003`
      If the hostnames are BNS addresses, the items must be sorted in
      alphabetical order.
  - `DTENSOR_CLIENT_ID`: integer, between `0` and `num_clients - 1`, to identify
      the client id of the current process. The default value is `0`.
  - `DTENSOR_JOB_NAME`: string, a string for the name of the TensorFlow job.
      The job name controls the job name section of the TensorFlow DeviceSpecs,
      e.g., `job:worker` in `/job:worker/replica:0/task:0/device:TPU:0` when
      the job name is `worker`.
      The default value is `localhost` in local mode, and
      `worker` when in the multi-client mode. All DTensor clients within the
      same multi-client cluster share the same job name.

  Args:
    device_type: Type of accelerator to use, can be CPU, GPU, or TPU. If None,
      uses `tf.experimental.dtensor.preferred_device_type()`.
    enable_coordination_service: If true, enable distributed coordination
      service to make sure that workers know the devices on each other, when
      there is more than 1 client.

  Returns:
    device_type: the type of accelerator that was initialized.
  """
  global _INITIALIZED_ACCELERATOR_SYSTEM_TYPE
  assert context.executing_eagerly()

  if _INITIALIZED_ACCELERATOR_SYSTEM_TYPE:
    raise ValueError(
        "Accelerator system has already been initialized. "
        "Call tf.experimental.dtensor.shutdown_accelerator_system() first.")

  context.context()._clear_caches()  # pylint: disable=protected-access

  if device_type is None:
    device_type = config.preferred_device_type()

  device_type = device_type.upper()
  if device_type not in {"CPU", "GPU", "TPU"}:
    raise ValueError(f"Unknown device_type {device_type}. "
                     "Allowed values are CPU, GPU, or TPU")

  # Reconfigure TensorFlow to use TFRT TPU runtime if requested.
  if device_type == "TPU":
    _configure_tpu_runtime()

  # Configure logical host CPU devices for accelerators.
  if device_type in ("GPU", "TPU"):
    num_local_devices = api.num_local_devices(device_type)
    if api.num_local_devices("CPU") < num_local_devices:
      tf_config.set_logical_device_configuration(
          tf_config.list_physical_devices("CPU")[0],
          [context.LogicalDeviceConfiguration()] * num_local_devices)

  if not config.is_local_mode():
    initialize_multi_client_cluster(
        job_name=config.job_name(),
        dtensor_jobs=config.jobs(),
        client_id=config.client_id(),
        collective_leader=config.full_job_name(task_id=0),
        enable_coordination_service=enable_coordination_service)

  if device_type == "TPU":
    tpu_util.initialize_tpu_system()

  _INITIALIZED_ACCELERATOR_SYSTEM_TYPE = device_type

  return device_type


@tf_export(
    "experimental.dtensor.shutdown_accelerator_system",
    "experimental.dtensor.shutdown_tpu_system",
    v1=[])
def shutdown_accelerator_system() -> None:
  """Shuts down the accelerator system."""
  global _INITIALIZED_ACCELERATOR_SYSTEM_TYPE
  context.async_wait()

  if not is_initialized():
    raise ValueError(
        "Accelerator system is not initialized. Call "
        "tf.experimental.dtensor.initialize_accelerator_system first.")

  device_type = _INITIALIZED_ACCELERATOR_SYSTEM_TYPE

  if not config.is_local_mode():
    raise ValueError(
        "Shutting down accelerator system under multi-client mode is "
        "not supported.")

  if device_type == "TPU":
    tpu_util.shutdown_tpu_system()

  # reset TF context to stop gRPC servers.
  context._reset_context()  # pylint: disable=protected-access
  context.context()._clear_caches()  # pylint: disable=protected-access
  _INITIALIZED_ACCELERATOR_SYSTEM_TYPE = None
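A hedged sketch of how one client in a two-host job might set the environment variables documented in the docstring above before initializing (hostnames, ports, and the GPU choice are illustrative placeholders, not from the commit):

```python
import os
from tensorflow.experimental import dtensor

# Illustrative values; every client process lists the same DTENSOR_JOBS and
# sets its own DTENSOR_CLIENT_ID.
os.environ["DTENSOR_JOBS"] = "host1:10000,host2:10000"  # non-empty => multi-client mode
os.environ["DTENSOR_CLIENT_ID"] = "0"                   # 0 .. num_clients - 1
os.environ["DTENSOR_JOB_NAME"] = "worker"               # multi-client default per the docstring

# With DTENSOR_JOBS set, this brings up gRPC servers and collectives via
# initialize_multi_client_cluster() before configuring the local devices.
dtensor.initialize_accelerator_system("GPU")
```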
@@ -466,6 +466,19 @@ def num_global_devices(device_type: str) -> int:
 # Private methods.


+def is_tpu_present() -> bool:
+  """Returns true if TPU devices are present."""
+  # Check if TPU is present from initialized context.
+  # TPU_SYSTEM is a logical device that indicates TPUs are present.
+  tpu_system_devices = tf_config.list_physical_devices("TPU_SYSTEM")
+  return len(tpu_system_devices) > 0  # pylint: disable=g-explicit-length-test
+
+
+def is_gpu_present() -> bool:
+  """Returns true if GPU devices are present."""
+  return len(tf_config.list_physical_devices("GPU")) > 0  # pylint: disable=g-explicit-length-test
+
+
 def _set_dtensor_device(device: dtensor_device.DTensorDevice) -> None:
   global _dtensor_singleton
   _dtensor_singleton = device
@@ -21,6 +21,7 @@ from tensorflow.python.framework import config as tf_config
 from tensorflow.python.util.tf_export import tf_export

 _DT_CLIENT_ID = "DTENSOR_CLIENT_ID"
+# DTENSOR_NUM_CLIENTS is removed, but some DTensor users still use this symbol.
 _DT_NUM_CLIENTS = "DTENSOR_NUM_CLIENTS"
 _DT_JOB_NAME = "DTENSOR_JOB_NAME"
 _DT_JOBS = "DTENSOR_JOBS"
@@ -48,13 +49,9 @@ def client_id() -> int:
 @tf_export("experimental.dtensor.num_clients", v1=[])
 def num_clients() -> int:
   """Returns the number of clients in this DTensor cluster."""
-  # If missing, assume running with a single client with num_clients of 1.
-  num_clients_value = int(os.environ.get(_DT_NUM_CLIENTS, "1"))
-  if num_clients_value <= 0:
-    raise ValueError(f"Environment variable {_DT_NUM_CLIENTS} "
-                     f"must be > 0, got {num_clients_value}.")
-
-  return num_clients_value
+  if is_local_mode():
+    return 1
+  return len(jobs())


 @tf_export("experimental.dtensor.job_name", v1=[])
@@ -123,6 +120,11 @@ def heartbeat_enabled() -> bool:
   return os.environ.get(_DT_HEARTBEAT_ENABLED, "true").lower() in ("true", "1")


+def is_local_mode() -> bool:
+  """Returns true if DTensor shall run in local mode."""
+  return not jobs()
+
+
 def is_tpu_present() -> bool:
   """Returns true if TPU devices are present."""
   # Check if TPU is present from initialized context.
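A minimal sketch of the mode selection implied by the `num_clients`/`is_local_mode` changes above; the helper names are hypothetical stand-ins for `config.jobs()` and friends, not the actual dtensor config module:

```python
import os

def _jobs():  # stand-in for config.jobs(): DTENSOR_JOBS split on ","
  value = os.environ.get("DTENSOR_JOBS", "")
  return value.split(",") if value else []

def _is_local_mode():  # local (single-controller) iff DTENSOR_JOBS is empty
  return not _jobs()

def _num_clients():  # 1 in local mode, otherwise one client per DTENSOR_JOBS entry
  return 1 if _is_local_mode() else len(_jobs())
```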
@@ -136,6 +138,7 @@ def is_gpu_present() -> bool:
   return bool(tf_config.list_physical_devices("GPU"))


+@tf_export("experimental.dtensor.preferred_device_type", v1=[])
 def preferred_device_type() -> str:
   """Returns the preferred device type for the accelerators.

@@ -18,10 +18,10 @@ from typing import List, Optional, Tuple
 from absl import logging
 import numpy as np

+from tensorflow.dtensor.python import accelerator_util
 from tensorflow.dtensor.python import api
 from tensorflow.dtensor.python import config
 from tensorflow.dtensor.python import layout
-from tensorflow.dtensor.python import multi_client_util
 from tensorflow.dtensor.python import tpu_util
 from tensorflow.python.eager import context
 from tensorflow.python.framework import config as tf_config
@@ -49,6 +49,7 @@ def _make_device_specs(
     device_type: Optional[str] = None
 ) -> Tuple[List[tf_device.DeviceSpec], str]:
   """Makes device specs from local devices names or number of global devices."""
+
   if devices is None:
     if device_type is None:
       device_type = 'CPU'
@@ -157,6 +158,10 @@ def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
   """
   dim_names, shape = zip(*mesh_dims)

+  if not accelerator_util.is_initialized():
+    raise ValueError('Accelerators are uninitialized, please run '
+                     'dtensor.initialize_accelerator_system() first.')
+
   if device_type and device_type.upper() == 'TPU':
     # TODO(b/185940495): Allow multi-mesh and partial on TPU.
     # TPU meshes can only be configured through environment variables that
@@ -174,10 +179,6 @@ def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
   # This is particularly useful on single clients when users want to create
   # meshes that use fewer logical devices than what's available.

-  if config.num_clients() > 1 and not multi_client_util.is_initialized():
-    raise ValueError('Invalid multi-client topology, please run '
-                     'dtensor.initialize_multi_client() first.')
-
   local_spec = tf_device.DeviceSpec(
       job=config.job_name(), replica=0, task=config.client_id())
   device_specs = [local_spec.make_merged_spec(d) for d in device_specs]
@@ -217,60 +218,6 @@ def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
     raise ValueError(f'Device type {device_type} is not CPU, GPU or TPU')


-@tf_export('experimental.dtensor.initialize_multi_client', v1=[])
-def dtensor_initialize_multi_client(
-    enable_coordination_service: Optional[bool] = False) -> None:
-  """Initializes Multi Client DTensor.
-
-  The following environment variables controls the behavior of this function.
-  If the variables are unset, DTensor will be configured to run in single-client
-  mode.
-
-  - DTENSOR_CLIENT_ID: integer, between 0 to num_clients - 1, to identify the
-    client id of the current process. The default value is 0.
-  - DTENSOR_NUM_CLIENTS: integer, the number of clients. The default value is 1.
-  - DTENSOR_JOB_NAME: string, a hostname like string for the name of the dtensor
-    job. The default is `localhost` when number of clients is 1, and `worker`
-    when the number of clients is greater than 1.
-    The job name controls the job name section of the TensorFlow DeviceSpecs,
-    e.g., `job:worker` in `/job:worker/replica:0/task:0/device:TPU:0` when
-    the job name is `worker`.
-  - DTENSOR_JOBS: string, a comma separated list. Each item in the list is
-    of format `{hostname}:{port}` and the items must be sorted in alphabet
-    order. The implication is the RPC port numbers of the clients from
-    the same host must be ordered by the client ID.
-    Examples of valid DTENSOR_JOBS values:
-    - 4 clients on localhost:
-      `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
-    - 2 clients on host1, 2 clients on host2
-      `host1:10000,host1:10001,host2:10000,host2:10003`
-
-  Args:
-    enable_coordination_service: If true, enable distributed coordination
-      service to make sure that workers know the devices on each other, a
-      prerequisite for data transfer through cross-worker rendezvous.
-  """
-  assert context.executing_eagerly()
-
-  # Collective GRPC servers are only necessary in multi-client setup.
-  # Single clients can use local mode of collectives.
-  if config.num_clients() > 1:
-    multi_client_util.initialize_multi_client_cluster(
-        job_name=config.job_name(),
-        dtensor_jobs=config.jobs(),
-        client_id=config.client_id(),
-        collective_leader=config.full_job_name(task_id=0),
-        enable_coordination_service=enable_coordination_service)
-
-  # Make sure the server change is fully propagated before returning.
-  context.ensure_initialized()
-  context.async_wait()
-  context.context()._clear_caches()  # pylint: disable=protected-access
-
-  # Unlike TPU, do not enable heartbeat service.
-  # They tend to interfere with regular GPU/CPU collective Ops.
-
-
 @tf_export('experimental.dtensor.barrier', v1=[])
 def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
   """Runs a barrier on the mesh.
tensorflow/dtensor/python/multi_client_util.py (deleted file, 115 lines)
@@ -1,115 +0,0 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility wrappers for working with multi-client setups."""

from typing import List, Optional

from absl import logging

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.dtensor.python import config
from tensorflow.python.eager import context
from tensorflow.python.platform import remote_utils

_is_multi_client_initialized = False


def initialize_multi_client_cluster(job_name: str,
                                    dtensor_jobs: List[str],
                                    client_id: int,
                                    collective_leader: str,
                                    port: Optional[int] = None,
                                    enable_coordination_service: bool = False):
  """Initialize GRPC servers and collectives for multi-client DTensor setup.

  While single clients (e.g. Forge) can use local mode of collectives, GRPC
  servers are necessary in mutli-client setup. This function can be used to
  initialize a cluster and enable collective ops.

  NOTE: this function must be called in an eager context.

  Args:
    job_name: The job name used by all clients in the DTensor cluster.
    dtensor_jobs: A list of the DTensor client jobs participating in the
      cluster. Must be strings of the form "hostname:port".
    client_id: The ID of the DTensor client this function is being called in.
    collective_leader: The job/task that will be used to run collectives.
    port: The port this client's GRPC server will run on.
    enable_coordination_service: If true, enable distributed coordination
      service to make sure that workers know the devices on each other, a
      prerequisite for data transfer through cross-worker rendezvous.

  Raises:
    RuntimeError: If running inside a tf.function.
  """
  global _is_multi_client_initialized
  assert context.executing_eagerly()

  if _is_multi_client_initialized:
    raise ValueError("Multi-client mode has already been initialized.")

  if config.num_clients() <= 1:
    raise ValueError(
        "DTENSOR_NUM_CLIENTS must be set greater than 1 for multi-client mode.")

  if not config.jobs() or len(config.jobs()) <= 1:
    raise ValueError(
        "DTENSOR_JOBS environment variable is required when using multi-client "
        "mode to properly set up communications between servers.")

  if len(config.jobs()) != config.num_clients():
    raise ValueError(
        "DTENSOR_JOBS environment variable must be configured with the same "
        "number of items as DTENSOR_NUM_CLIENTS.")

  if not collective_leader.startswith("/job:"):
    collective_leader = "/job:" + collective_leader

  context.context().configure_collective_ops(
      collective_leader=collective_leader)
  if enable_coordination_service:
    context.context().configure_coordination_service(
        service_type="standalone", service_leader=collective_leader)

  config_proto = context.get_config()
  config_proto.experimental.collective_group_leader = collective_leader
  # Construct server def from the host directly instead of relying on
  # TF_CONFIG.
  cluster_def = cluster_pb2.ClusterDef()
  # Note that we will currently rely on the sorted string of job name as the
  # order of assigning task ids. This might be brittle once we have jobs
  # across multiple cells.
  cluster_def.job.add(name=job_name, tasks=dict(enumerate(dtensor_jobs)))
  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_def,
      default_session_config=config_proto,
      job_name=job_name,
      task_index=client_id,
      protocol=remote_utils.get_default_communication_protocol(),
      port=port)
  server_def.default_session_config.rpc_options.num_channels_per_target = 4
  server_def.default_session_config.experimental.recv_buf_max_chunk = -1

  logging.info("Enabling collectives with server_def: %s", server_def)
  context.context().enable_collective_ops(server_def)
  context.ensure_initialized()

  _is_multi_client_initialized = True


def is_initialized() -> bool:
  """Returns whether multi-client mode has been initialized."""
  return _is_multi_client_initialized
@@ -18,7 +18,6 @@ import functools
 import time
 from typing import List, Optional, Dict

-from absl import flags
 import numpy as np

 from tensorflow.dtensor.python import api
@@ -27,22 +26,18 @@ from tensorflow.dtensor.python import dtensor_device
 from tensorflow.dtensor.python import gen_dtensor_ops
 from tensorflow.dtensor.python import heartbeat
 from tensorflow.dtensor.python import layout as layout_lib
-from tensorflow.dtensor.python import multi_client_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tfrt_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import topology
-from tensorflow.python.util.tf_export import tf_export

-_INITIALIZED_TPU_SYSTEMS = {}
 _MESH_DIM_X = "x"
 _TPU_DEVICE_TYPE = "TPU"
@@ -132,9 +127,8 @@ def _create_tpu_topology(core_locations: List[_CoreLocation], num_tasks: int,
       mesh_shape=mesh_shape, device_coordinates=device_coordinates)


-@tf_export("experimental.dtensor.shutdown_tpu_system", v1=[])
-def dtensor_shutdown_tpu_system():
-  """Shutdown TPU system."""
+def shutdown_tpu_system():
+  """Shuts down the TPU system."""

   @def_function.function
   def _shutdown_tpu_system():
@@ -147,39 +141,8 @@ def dtensor_shutdown_tpu_system():
     logging.warning("TPU system fails to shut down.")


-@tf_export("experimental.dtensor.initialize_tpu_system", v1=[])
-def dtensor_initialize_tpu_system(enable_coordination_service=False):
-  """Initialize the TPU devices.
-
-  This functions performs additional TPU related initialization after
-  calling `dtensor.initialize_multi_client` to initialize multi-client DTensor.
-  Refer to `dtensor.initialize_multi_client` for relevant environment
-  variables that controls the initialization of multi-client DTensor.
-
-  Args:
-    enable_coordination_service: If true, enable distributed coordination
-      service to make sure that workers know the devices on each other, a
-      prerequisite for data transfer through cross-worker rendezvous.
-
-  Raises:
-    RuntimeError: If running inside a tf.function.
-    NotFoundError: If no TPU devices found in eager mode.
-  """
-
-  assert context.executing_eagerly()
-
-  # Reconfigure TensorFlow to use TFRT TPU runtime if requested.
-  _configure_tpu_runtime()
-
-  # Collective GRPC servers are only necessary in mutli-client setup.
-  # Single clients can use local mode of collectives.
-  if config.num_clients() > 1 and not multi_client_util.is_initialized():
-    multi_client_util.initialize_multi_client_cluster(
-        job_name=config.job_name(),
-        dtensor_jobs=config.jobs(),
-        client_id=config.client_id(),
-        collective_leader=config.full_job_name(task_id=0),
-        enable_coordination_service=enable_coordination_service)
+def initialize_tpu_system():
+  """Initializes the TPU system."""

   # Make sure the server change is fully propagated before attempting to run
   # the core ID merging logic below.
@@ -199,14 +162,6 @@ def dtensor_initialize_tpu_system(enable_coordination_service=False):
   with ops.device("/job:" + config.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
     my_core_ids = _tpu_init_fn()
   logging.info("TPU core IDs: %s", my_core_ids)
-  context.initialize_logical_devices()
-
-  # Configure virtual CPUs that is 1:1 mapped to TPU cores.
-  context.context().set_logical_cpu_devices(
-      len(api.local_devices(_TPU_DEVICE_TYPE)),
-      tf_device.DeviceSpec(
-          job=config.job_name(), replica=0,
-          task=config.client_id()).to_string())

   # `my_core_ids` contains the IDs of TPU cores attached to this host.
   #
@@ -727,8 +682,8 @@ def create_tpu_mesh(mesh_dim_names: List[str],
   # easier interaction with the C++ API.
   global_core_locations = [l.to_list() for l in global_core_locations]
   if _dtensor_device is None:
-    raise ValueError(
-        "Invalid system device, run dtensor.initialize_tpu_system() first")
+    raise ValueError("Invalid system device, "
+                     "run dtensor.initialize_accelerator_system() first")
   global_core_ids = _dtensor_device.tpu_core_locations_to_ids(
       global_core_locations)
@@ -809,9 +764,16 @@ def get_device_locations(
         "Looking up other clients' device locations is not supported")


-def _configure_tpu_runtime():
-  was_enabled = context.is_tfrt_enabled()
-  if ("tpu_use_tfrt" in flags.FLAGS and flags.FLAGS["tpu_use_tfrt"].value):
-    tfrt_utils.set_tfrt_enabled(True)
-  if not was_enabled:
-    context._reset_context()  # pylint:disable=protected-access
+# TODO(b/245589661): Remove dtensor_initialize_tpu_system() and
+# dtensor_shutdown_tpu_system() after users stopped using them.
+def dtensor_initialize_tpu_system(enable_coordination_service=False):
+  """Deprecated way to initialize the TPU system."""
+  from . import accelerator_util  # pylint: disable=g-import-not-at-top
+  accelerator_util.initialize_accelerator_system(
+      "TPU", enable_coordination_service=enable_coordination_service)
+
+
+def dtensor_shutdown_tpu_system():
+  """Deprecated way to shutdown the TPU system."""
+  from . import accelerator_util  # pylint: disable=g-import-not-at-top
+  accelerator_util.shutdown_accelerator_system()
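For call sites that used the removed entry points, the deprecated wrappers above suggest a mechanical migration; a sketch, not from the commit:

```python
from tensorflow.experimental import dtensor

# Before: dtensor.initialize_multi_client(...) and/or dtensor.initialize_tpu_system(...)
# After: one call, with the device type made explicit.
dtensor.initialize_accelerator_system("TPU", enable_coordination_service=True)

# Before: dtensor.shutdown_tpu_system()
# After:
dtensor.shutdown_accelerator_system()
```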
@@ -45,6 +45,7 @@ def gen_api_init_files(
         compat_init_templates = [],
         packages = [
             "tensorflow.python",
+            "tensorflow.dtensor.python.accelerator_util",
             "tensorflow.dtensor.python.api",
             "tensorflow.dtensor.python.config",
             "tensorflow.dtensor.python.d_checkpoint",
@@ -53,7 +54,6 @@ def gen_api_init_files(
             "tensorflow.dtensor.python.layout",
             "tensorflow.dtensor.python.mesh_util",
             "tensorflow.dtensor.python.save_restore",
-            "tensorflow.dtensor.python.tpu_util",
             "tensorflow.lite.python.analyzer",
             "tensorflow.lite.python.lite",
             "tensorflow.lite.python.authoring.authoring",
@@ -76,13 +76,17 @@ tf_module {
     name: "heartbeat_enabled"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "initialize_accelerator_system"
+    argspec: "args=[\'device_type\', \'enable_coordination_service\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
   member_method {
     name: "initialize_multi_client"
-    argspec: "args=[\'enable_coordination_service\'], varargs=None, keywords=None, defaults=[\'False\'], "
+    argspec: "args=[\'device_type\', \'enable_coordination_service\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
   }
   member_method {
     name: "initialize_tpu_system"
-    argspec: "args=[\'enable_coordination_service\'], varargs=None, keywords=None, defaults=[\'False\'], "
+    argspec: "args=[\'device_type\', \'enable_coordination_service\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
   }
   member_method {
     name: "job_name"
@@ -120,6 +124,10 @@ tf_module {
     name: "pack"
     argspec: "args=[\'tensors\', \'layout\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "preferred_device_type"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "relayout"
     argspec: "args=[\'tensor\', \'layout\'], varargs=None, keywords=None, defaults=None"
@@ -132,6 +140,10 @@ tf_module {
     name: "sharded_save"
     argspec: "args=[\'mesh\', \'file_prefix\', \'tensor_names\', \'shape_and_slices\', \'tensors\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "shutdown_accelerator_system"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "shutdown_tpu_system"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"