Use merge_call for optimizer distributed_apply if the strategy doesn't support the merge_call-free path.

In principle, even if the strategy doesn't support the merge_call-free path, distributed_apply can still be called without merge_call. However, in pure eager mode this would be slower than calling it with merge_call, mainly because MirroredStrategy uses multiple Python threads to run the replica function, and those threads run sequentially.
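
For illustration, a minimal custom-training-loop sketch of the user-facing
path this change affects. The MirroredStrategy/SGD choices below are
illustrative only; whether the merge_call branch is taken depends on what
optimizer_utils.strategy_supports_no_merge_call() reports for the strategy
in use.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.0)
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step():
  def replica_fn():
    with tf.GradientTape() as tape:
      loss = v * v
    grads = tape.gradient(loss, [v])
    # Runs in replica context; the optimizer decides internally whether
    # _distributed_apply goes through merge_call or the merge_call-free path.
    opt.apply_gradients(zip(grads, [v]))
  strategy.run(replica_fn)

train_step()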

PiperOrigin-RevId: 367305737
Change-Id: Ie2316fd62131c60eff605cf9f51e044c09ec9abc
Chenkai Kuang 2021-04-07 15:02:57 -07:00
parent fb37439d64
commit fcda86f6b8
3 changed files with 35 additions and 16 deletions


@@ -2562,13 +2562,21 @@ class StrategyExtendedV2(object):
     where each list has an element per replica, and the caller is responsible
     for ensuring all elements are executed.
     """
-    _require_cross_replica_or_default_context_extended(self)
+    # TODO(b/178944108): Update the documentation to reflect the fact that
+    # `update` can be called in a replica context.
     if kwargs is None:
       kwargs = {}
-    fn = autograph.tf_convert(
-        fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
-    with self._container_strategy().scope():
-      return self._update(var, fn, args, kwargs, group)
+    replica_context = distribution_strategy_context.get_replica_context()
+    # pylint: disable=protected-access
+    if (replica_context is None or replica_context is
+        distribution_strategy_context._get_default_replica_context()):
+      fn = autograph.tf_convert(
+          fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
+      with self._container_strategy().scope():
+        return self._update(var, fn, args, kwargs, group)
+    else:
+      return self._replica_ctx_update(
+          var, fn, args=args, kwargs=kwargs, group=group)

   def _update(self, var, fn, args, kwargs, group):
     raise NotImplementedError("must be implemented in descendants")


@@ -46,7 +46,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_util
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
@@ -773,10 +772,6 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
   def _update(self, var, fn, args, kwargs, group):
     # TODO(josh11b): In eager mode, use one thread per device.
     assert isinstance(var, values.DistributedVariable)
-    if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and
-        var.aggregation != variables_lib.VariableAggregation.NONE):
-      distribute_utils.assert_mirrored(args)
-      distribute_utils.assert_mirrored(kwargs)
     updates = []
     for i, v in enumerate(var.values):
       name = "update_%d" % i
@@ -838,8 +833,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       name = "update_%d" % i
       with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
         updates.append(
-            fn(*distribute_utils.select_replica_mirrored(i, args),
-               **distribute_utils.select_replica_mirrored(i, kwargs)))
+            fn(*distribute_utils.select_replica(i, args),
+               **distribute_utils.select_replica(i, kwargs)))
     return distribute_utils.update_regroup(self, updates, group)

   def read_var(self, replica_local_var):


@@ -663,8 +663,16 @@ class OptimizerV2(trackable.Trackable):
         grads_and_vars = self._aggregate_gradients(grads_and_vars)
       grads_and_vars = self._transform_gradients(grads_and_vars)
-      return self._distributed_apply(strategy, grads_and_vars, name,
-                                     apply_state)
+      if optimizer_utils.strategy_supports_no_merge_call():
+        return self._distributed_apply(strategy, grads_and_vars, name,
+                                       apply_state)
+      else:
+        return distribute_ctx.get_replica_context().merge_call(
+            functools.partial(self._distributed_apply, apply_state=apply_state),
+            args=(grads_and_vars,),
+            kwargs={
+                "name": name,
+            })

   def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
     """`apply_gradients` using a `DistributionStrategy`."""
@@ -703,8 +711,16 @@ class OptimizerV2(trackable.Trackable):
         with name_scope_only_in_function_or_graph(
             "update" if eagerly_outside_functions else "update_" +
             var.op.name):
-          update_ops.append(distribute_ctx.get_replica_context()._update(  # pylint: disable=protected-access
-              var, apply_grad_to_update_var, args=(grad,), group=False))
+          update_op = distribution.extended.update(
+              var, apply_grad_to_update_var, args=(grad,), group=False)
+          if distribute_ctx.in_cross_replica_context():
+            # In cross-replica context, extended.update returns a list of
+            # update ops from all replicas (group=False).
+            update_ops.extend(update_op)
+          else:
+            # In replica context, extended.update returns the single update op
+            # of the current replica.
+            update_ops.append(update_op)

       any_symbolic = any(isinstance(i, ops.Operation) or
                          tf_utils.is_symbolic_tensor(i) for i in update_ops)