Add dtensor.default_mesh() as a replacement for dtensor.run_on().

dtensor.run_on() tends to cause confusion since its name suggests override semantics (forcing the functions and ops in the context to run on the supplied mesh), even though the implementation only provides fallback semantics.
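For illustration, a minimal sketch of the fallback behavior (the mesh construction and dimension name below are hypothetical, not part of this change):

    import tensorflow as tf
    from tensorflow.experimental import dtensor

    # Hypothetical single-device mesh so the sketch runs anywhere.
    mesh = dtensor.create_mesh([("batch", 1)])

    with dtensor.default_mesh(mesh):
      # No mesh can be inferred from the inputs here, so the op falls
      # back to `mesh` and produces a DTensor replicated on it.
      a = tf.zeros([4, 8])
    # Ops whose inputs already carry a mesh are unaffected by the scope.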

PiperOrigin-RevId: 513949632
Srujun Thanmay Gupta 2023-03-03 16:24:56 -08:00 committed by TensorFlower Gardener
parent 4471e075b3
commit 85151a471c
4 changed files with 36 additions and 0 deletions

RELEASE.md

@@ -82,6 +82,13 @@
retrieving the worker index from within a worker, when using parameter
server training with a custom training loop.
* `tf.experimental.dtensor`:
* Deprecated `dtensor.run_on` in favor of `dtensor.default_mesh` to
  correctly indicate that the context does not override the mesh that the
  ops and functions will run on; it only sets a fallback default mesh.
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:

tensorflow/dtensor/python/BUILD

@@ -48,6 +48,7 @@ pytype_strict_library(
":layout",
"//tensorflow/python/eager:context",
"//tensorflow/python/framework:ops",
"//tensorflow/python/util",
"//tensorflow/python/util:tf_export",
],
)

tensorflow/dtensor/python/api.py

@@ -23,6 +23,7 @@ from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
_dtensor_singleton = None
@@ -73,6 +74,7 @@ def call_with_layout(fn: Callable[...,
@tf_export("experimental.dtensor.run_on", v1=[])
@deprecation.deprecated(None, "Use `dtensor.default_mesh` scope instead.")
@contextlib.contextmanager
def run_on(mesh: layout_lib.Mesh):
"""Runs enclosed functions in the DTensor device scope.
@@ -89,6 +91,28 @@ def run_on(mesh: layout_lib.Mesh):
Yields:
A context in which all ops and tf.functions will run on the DTensor device.
"""
with default_mesh(mesh):
yield
@tf_export("experimental.dtensor.default_mesh", v1=[])
@contextlib.contextmanager
def default_mesh(mesh: layout_lib.Mesh):
"""Sets the default DTensor device mesh to use for enclosed functions.
This function returns a scope. All the ops and tf.functions in this scope will
default to this DTensor mesh if a mesh cannot be inferred from any of the
inputs.

This is useful for wrapping any tf.function that doesn't take a DTensor as
input but would like to produce a DTensor as its result. The scope will also
make sure all small constants are replicated as DTensors.
Args:
mesh: A Mesh instance to use as the default mesh for this scope.
Yields:
A context in which all ops and tf.functions will run on the given mesh.
"""
if not isinstance(mesh, layout_lib.Mesh):
raise ValueError(f"Expected `mesh` to be a `Mesh` instance, got {type(mesh)}")
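As a usage note (illustrative, not part of the diff): after this change, run_on simply delegates to default_mesh and emits a deprecation warning, and default_mesh rejects non-Mesh arguments when the scope is entered:

    with dtensor.run_on(mesh):  # warns: "Use `dtensor.default_mesh` scope instead."
      b = tf.ones([2, 2])       # same fallback behavior as default_mesh

    with dtensor.default_mesh("not a mesh"):  # raises ValueError on entering the scope
      pass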

tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.pbtxt

@@ -60,6 +60,10 @@ tf_module {
name: "create_tpu_mesh"
argspec: "args=[\'mesh_dim_names\', \'mesh_shape\', \'mesh_name\', \'ring_dims\', \'ring_axes\', \'ring_bounds\', \'can_split_host_across_rings\', \'build_ring_across_rings\', \'rotate_ring_across_rings\', \'use_xla_spmd\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "default_mesh"
argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
}
member_method {
name: "device_name"
argspec: "args=[], varargs=None, keywords=None, defaults=None"