Make sure that Adam colocates ops with a consistent variable across workers.
PiperOrigin-RevId: 158022292
This commit is contained in:
parent 69ba4d3d49
commit 504a307b74
tensorflow/python/training/adam.py

@@ -113,10 +113,14 @@ class AdamOptimizer(optimizer.Optimizer):
 
   def _create_slots(self, var_list):
     # Create the beta1 and beta2 accumulators on the same device as the first
-    # variable.
+    # variable. Sort the var_list to make sure this device is consistent across
+    # workers (these need to go on the same PS, otherwise some updates are
+    # silently ignored).
+    first_var = min(var_list, key=lambda x: x.name)
+
     if (self._beta1_power is None or
-        self._beta1_power.graph is not var_list[0].graph):
-      with ops.colocate_with(var_list[0]):
+        self._beta1_power.graph is not first_var.graph):
+      with ops.colocate_with(first_var):
         self._beta1_power = variable_scope.variable(self._beta1,
                                                     name="beta1_power",
                                                     trainable=False)
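
Why the change matters: in between-graph replicated training each worker builds its own copy of the graph, and the order of var_list is not guaranteed to match across workers, so var_list[0] can refer to different variables on different workers. Picking the variable with the lexicographically smallest name is order-independent, so every worker colocates the beta accumulators with the same variable, and hence the same parameter server. A minimal sketch of the difference in plain Python, where FakeVar is a hypothetical stand-in for a variable object with a .name attribute (it is not part of the commit):

# FakeVar is a hypothetical stand-in for tf.Variable; only .name matters here.
class FakeVar(object):
  def __init__(self, name):
    self.name = name

# Two workers happen to build their variable lists in different orders.
worker_a = [FakeVar("dense/kernel"), FakeVar("dense/bias")]
worker_b = [FakeVar("dense/bias"), FakeVar("dense/kernel")]

# var_list[0] disagrees between workers, so the accumulators would be
# colocated with different variables (placed on different PS tasks), and
# some updates would be silently dropped.
assert worker_a[0].name != worker_b[0].name

# min() by name picks the same variable on every worker.
pick_a = min(worker_a, key=lambda x: x.name)
pick_b = min(worker_b, key=lambda x: x.name)
assert pick_a.name == pick_b.name == "dense/bias"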