mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add dtensor.default_mesh() as a replacement for dtensor.run_on().
dtensor.run_on() tends to cause confusion since it indicates override semantics (force the functions and ops in the context to run on the supplied mesh), even though the implementation only provides fallback semantics. PiperOrigin-RevId: 513949632
This commit is contained in:
parent
4471e075b3
commit
85151a471c
|
|
@ -82,6 +82,13 @@
|
|||
retrieving the worker index from within a worker, when using parameter
|
||||
server training with a custom training loop.
|
||||
|
||||
* `tf.experimental.dtensor`:
|
||||
|
||||
* Deprecated `dtensor.run_on` in favor of `dtensor.default_mesh` to
|
||||
correctly indicate that the context does not override the mesh that the
|
||||
ops and functions will run on; it only sets a fallback default mesh.
|
||||
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
This release contains contributions from many people at Google, as well as:
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ pytype_strict_library(
|
|||
":layout",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/framework:ops",
|
||||
"//tensorflow/python/util",
|
||||
"//tensorflow/python/util:tf_export",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from tensorflow.dtensor.python import gen_dtensor_ops
|
|||
from tensorflow.dtensor.python import layout as layout_lib
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
_dtensor_singleton = None
|
||||
|
|
@ -73,6 +74,7 @@ def call_with_layout(fn: Callable[...,
|
|||
|
||||
|
||||
@tf_export("experimental.dtensor.run_on", v1=[])
@deprecation.deprecated(None, "Use `dtensor.default_mesh` scope instead.")
@contextlib.contextmanager
def run_on(mesh: layout_lib.Mesh):
  """Runs enclosed functions in the DTensor device scope.

  Deprecated alias for `dtensor.default_mesh`: despite the name, this scope
  only supplies a fallback mesh — it does not force ops onto `mesh` when a
  mesh can be inferred from their inputs.

  Args:
    mesh: A Mesh instance to extract a default mesh from.

  Yields:
    A context in which all ops and tf.functions will run on the DTensor device.
  """
  # Delegate entirely to the replacement scope so both share one code path.
  scope = default_mesh(mesh)
  with scope:
    yield
|
||||
|
||||
|
||||
@tf_export("experimental.dtensor.default_mesh", v1=[])
|
||||
@contextlib.contextmanager
|
||||
def default_mesh(mesh: layout_lib.Mesh):
|
||||
"""Sets the default DTensor device mesh to use for enclosed functions.
|
||||
|
||||
This function returns a scope. All the ops and tf.functions in this scope will
|
||||
default to this DTensor mesh if a mesh cannot be inferred from any of the
|
||||
inputs.
|
||||
This is useful for wrapping any tf.function that doesn't take a DTensor as
|
||||
input but would like to produce DTensor as result. The scope will also make
|
||||
sure all small constants are replicated as DTensors.
|
||||
|
||||
Args:
|
||||
mesh: A Mesh instance to extract a default mesh from.
|
||||
|
||||
Yields:
|
||||
A context in which all ops and tf.functions will run on the given mesh.
|
||||
"""
|
||||
if not isinstance(mesh, layout_lib.Mesh):
|
||||
raise ValueError(f"Expect `mesh` to be `Mesh`, got {type(mesh)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -60,6 +60,10 @@ tf_module {
|
|||
name: "create_tpu_mesh"
|
||||
argspec: "args=[\'mesh_dim_names\', \'mesh_shape\', \'mesh_name\', \'ring_dims\', \'ring_axes\', \'ring_bounds\', \'can_split_host_across_rings\', \'build_ring_across_rings\', \'rotate_ring_across_rings\', \'use_xla_spmd\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'False\', \'False\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "default_mesh"
|
||||
argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "device_name"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user