mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Use merge_call for optimizer distributed_apply if strategy doesn't support merge_call free path.
In principle, even if the strategy doesn't support the merge_call-free path, distributed_apply can still be called without merge_call. However, it would be slower than calling it with merge_call in pure eager mode, mainly because MirroredStrategy uses multiple Python threads to run the replica function, and those Python threads run sequentially. PiperOrigin-RevId: 367305737 Change-Id: Ie2316fd62131c60eff605cf9f51e044c09ec9abc
This commit is contained in:
parent
fb37439d64
commit
fcda86f6b8
|
|
@ -2562,13 +2562,21 @@ class StrategyExtendedV2(object):
|
|||
where each list has an element per replica, and the caller is responsible
|
||||
for ensuring all elements are executed.
|
||||
"""
|
||||
_require_cross_replica_or_default_context_extended(self)
|
||||
# TODO(b/178944108): Update the documentation to reflect the fact that
|
||||
# `update` can be called in a replica context.
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
fn = autograph.tf_convert(
|
||||
fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
|
||||
with self._container_strategy().scope():
|
||||
return self._update(var, fn, args, kwargs, group)
|
||||
replica_context = distribution_strategy_context.get_replica_context()
|
||||
# pylint: disable=protected-access
|
||||
if (replica_context is None or replica_context is
|
||||
distribution_strategy_context._get_default_replica_context()):
|
||||
fn = autograph.tf_convert(
|
||||
fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
|
||||
with self._container_strategy().scope():
|
||||
return self._update(var, fn, args, kwargs, group)
|
||||
else:
|
||||
return self._replica_ctx_update(
|
||||
var, fn, args=args, kwargs=kwargs, group=group)
|
||||
|
||||
def _update(self, var, fn, args, kwargs, group):
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
|
|
|||
|
|
@ -46,7 +46,6 @@ from tensorflow.python.framework import ops
|
|||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
|
@ -773,10 +772,6 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
|||
def _update(self, var, fn, args, kwargs, group):
|
||||
# TODO(josh11b): In eager mode, use one thread per device.
|
||||
assert isinstance(var, values.DistributedVariable)
|
||||
if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and
|
||||
var.aggregation != variables_lib.VariableAggregation.NONE):
|
||||
distribute_utils.assert_mirrored(args)
|
||||
distribute_utils.assert_mirrored(kwargs)
|
||||
updates = []
|
||||
for i, v in enumerate(var.values):
|
||||
name = "update_%d" % i
|
||||
|
|
@ -838,8 +833,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
|||
name = "update_%d" % i
|
||||
with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
|
||||
updates.append(
|
||||
fn(*distribute_utils.select_replica_mirrored(i, args),
|
||||
**distribute_utils.select_replica_mirrored(i, kwargs)))
|
||||
fn(*distribute_utils.select_replica(i, args),
|
||||
**distribute_utils.select_replica(i, kwargs)))
|
||||
return distribute_utils.update_regroup(self, updates, group)
|
||||
|
||||
def read_var(self, replica_local_var):
|
||||
|
|
|
|||
|
|
@ -663,8 +663,16 @@ class OptimizerV2(trackable.Trackable):
|
|||
grads_and_vars = self._aggregate_gradients(grads_and_vars)
|
||||
grads_and_vars = self._transform_gradients(grads_and_vars)
|
||||
|
||||
return self._distributed_apply(strategy, grads_and_vars, name,
|
||||
apply_state)
|
||||
if optimizer_utils.strategy_supports_no_merge_call():
|
||||
return self._distributed_apply(strategy, grads_and_vars, name,
|
||||
apply_state)
|
||||
else:
|
||||
return distribute_ctx.get_replica_context().merge_call(
|
||||
functools.partial(self._distributed_apply, apply_state=apply_state),
|
||||
args=(grads_and_vars,),
|
||||
kwargs={
|
||||
"name": name,
|
||||
})
|
||||
|
||||
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
|
||||
"""`apply_gradients` using a `DistributionStrategy`."""
|
||||
|
|
@ -703,8 +711,16 @@ class OptimizerV2(trackable.Trackable):
|
|||
with name_scope_only_in_function_or_graph(
|
||||
"update" if eagerly_outside_functions else "update_" +
|
||||
var.op.name):
|
||||
update_ops.append(distribute_ctx.get_replica_context()._update( # pylint: disable=protected-access
|
||||
var, apply_grad_to_update_var, args=(grad,), group=False))
|
||||
update_op = distribution.extended.update(
|
||||
var, apply_grad_to_update_var, args=(grad,), group=False)
|
||||
if distribute_ctx.in_cross_replica_context():
|
||||
# In cross-replica context, extended.update returns a list of
|
||||
# update ops from all replicas (group=False).
|
||||
update_ops.extend(update_op)
|
||||
else:
|
||||
# In replica context, extended.update return the single update op
|
||||
# of current replica.
|
||||
update_ops.append(update_op)
|
||||
|
||||
any_symbolic = any(isinstance(i, ops.Operation) or
|
||||
tf_utils.is_symbolic_tensor(i) for i in update_ops)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user