Smart Decay for Adam - DPER3 (#62058)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62058

This is the second diff in this stack: it contains the DPER3 changes, while the first diff contains the Caffe2 changes.

We want to decay learning parameters properly. Previously this was not done when a parameter was absent from a minibatch. We fix this by keeping track of missed minibatches and making the decay catch up accordingly.

The exponential moving averages (EMAs) of the first and second moments used in Adam are updated only for parameters seen in a minibatch. Strictly speaking, for parameters not seen in a minibatch, 0 should be added to the EMAs and the EMAs should still be decayed by beta1 and beta2, respectively. To avoid the computational overhead of touching every parameter on every minibatch, we:

* keep track of the last minibatch in which each parameter was seen, and
* instead of decaying the EMAs by beta1 and beta2, decay them by beta1^k and beta2^k, where k is the number of minibatches since the parameter was last seen.

We hope this will significantly reduce the inconsistent-learning-parameter issue we have seen with Adam.

Differential Revision: D29638897

fbshipit-source-id: 18d8e227d72c2e23010ca81e0f6eeb78872c8d3c
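A rough NumPy sketch of this bookkeeping for one sparse update (the function name, the explicit Python loop, and the omission of bias correction are illustrative assumptions; the actual kernel lives in the Caffe2 SmartDecaySparseAdam op added in the companion diff):

import numpy as np

def smart_decay_sparse_adam_step(param, m1, m2, last_seen, indices, grad,
                                 lr, t, beta1=0.9, beta2=0.999, epsilon=1e-8):
    # Only rows that appear in this minibatch are touched; their EMAs first
    # "catch up" on the k missed decays before the normal Adam update.
    for row, g in zip(indices, grad):
        k = t - last_seen[row]                     # minibatches since last seen
        m1[row] = beta1 ** k * m1[row] + (1.0 - beta1) * g
        m2[row] = beta2 ** k * m2[row] + (1.0 - beta2) * g * g
        param[row] -= lr * m1[row] / (np.sqrt(m2[row]) + epsilon)
        last_seen[row] = t                         # mark the row as seen now
    return param, m1, m2, last_seen

With k = 1 (the row was seen in the previous minibatch) this reduces to the ordinary Adam EMA update; with k > 1 the single multiplication by beta^k is equivalent, for the EMAs, to having fed k - 1 zero gradients in between, without ever touching unseen rows.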
This commit is contained in:
parent
5224490ae9
commit
812bc1dde6
@@ -1514,6 +1514,7 @@ class AdamOptimizer(Optimizer):
        rowWise=False,
        engine="",
        enableRAdam=False,
        use_smart_decay=False,  # See https://fburl.com/2jdiwrhy for context.
        **kwargs
    ):
        super(AdamOptimizer, self).__init__()
@@ -1529,6 +1530,18 @@ class AdamOptimizer(Optimizer):
        self.rowWise = rowWise
        self.engine = engine
        self.enableRAdam = enableRAdam
        if use_smart_decay:
            if rowWise:
                raise NotImplementedError(('Smart decay is not implemented for rowWise Adam. '
                                           'Set rowWise or use_smart_decay to False.'))
            if enableRAdam:
                raise NotImplementedError(('Smart decay is not implemented for RAdam. '
                                           'Set enableRAdam or use_smart_decay to False.'))
            if use_lr_adaption:
                raise NotImplementedError(('Smart decay is not implemented with lr_adaption. '
                                           'Set use_lr_adaption or use_smart_decay to False.'))

        self.use_smart_decay = use_smart_decay
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
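For illustration only (not part of the diff), the validation added above can be exercised directly; AdamOptimizer comes from caffe2.python.optimizer, and constructing the optimizer object does not require a model:

from caffe2.python.optimizer import AdamOptimizer

# Smart decay on its own is a valid configuration.
opt = AdamOptimizer(alpha=0.001, beta1=0.9, beta2=0.999, use_smart_decay=True)

# Combining it with rowWise (or enableRAdam, or use_lr_adaption) hits the
# NotImplementedError checks added in __init__ above.
try:
    AdamOptimizer(rowWise=True, use_smart_decay=True)
except NotImplementedError as err:
    print(err)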
@@ -1558,6 +1571,14 @@ class AdamOptimizer(Optimizer):
                [param], param + "_second_moment", value=0.0
            )

        # Initialize "minibatch in which this parameter was last seen" for smart decay.
        if self.use_smart_decay:
            shapes, _ = workspace.InferShapesAndTypes([param_init_net])
            last_seen = param_init_net.ConstantFill(
                [], param + "_last_seen", shape=[shapes[param][0]], value=0, dtype=core.DataType.INT64
            )
            self._aux_params.local.append(last_seen)

        self._aux_params.shared.append(iteration)
        self._aux_params.local.append(m1)
        self._aux_params.local.append(m2)
@@ -1570,6 +1591,10 @@ class AdamOptimizer(Optimizer):
            )

        output_blobs = [param, m1, m2]

        if self.use_smart_decay:
            output_blobs.append(last_seen)

        if self.use_lr_adaption:
            effective_grad = str(param) + "_effective_grad"
            output_blobs.append(effective_grad)
@@ -1578,6 +1603,8 @@ class AdamOptimizer(Optimizer):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            if self.rowWise:
                op = "RowWiseSparseAdam"
            elif self.use_smart_decay:
                op = "SmartDecaySparseAdam"
            else:
                op = "SparseAdam"
@@ -1591,6 +1618,14 @@ class AdamOptimizer(Optimizer):
                    epsilon=self.epsilon,
                    enableRAdam=self.enableRAdam,
                )
            elif op == "SmartDecaySparseAdam":
                net.__getattr__(op)(
                    [param, m1, m2, last_seen, grad.indices, grad.values, lr, iteration],
                    output_blobs,
                    beta1=self.beta1,
                    beta2=self.beta2,
                    epsilon=self.epsilon,
                )
            else:
                assert (
                    not self.enableRAdam
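The new operator path can also be exercised on its own, which makes the blob order wired above explicit. A minimal sketch, assuming the SmartDecaySparseAdam op registered by the companion Caffe2 diff accepts exactly the inputs, outputs, and arguments used here; all blob names, shapes, and values are made up for illustration:

import numpy as np
from caffe2.python import core, workspace

rows, dim, t = 4, 3, 10
workspace.FeedBlob("param", np.random.rand(rows, dim).astype(np.float32))
workspace.FeedBlob("m1", np.zeros((rows, dim), dtype=np.float32))
workspace.FeedBlob("m2", np.zeros((rows, dim), dtype=np.float32))
workspace.FeedBlob("last_seen", np.zeros(rows, dtype=np.int64))
workspace.FeedBlob("indices", np.array([0, 2], dtype=np.int64))   # rows seen now
workspace.FeedBlob("grad", np.random.rand(2, dim).astype(np.float32))
workspace.FeedBlob("lr", np.array([0.1], dtype=np.float32))
workspace.FeedBlob("iter", np.array([t], dtype=np.int64))

op = core.CreateOperator(
    "SmartDecaySparseAdam",
    ["param", "m1", "m2", "last_seen", "indices", "grad", "lr", "iter"],
    ["param", "m1", "m2", "last_seen"],
    beta1=0.9, beta2=0.999, epsilon=1e-8,
)
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("last_seen"))  # rows 0 and 2 are now marked as seen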
@@ -241,6 +241,23 @@ class TestAdam(OptimizerTestBase, LRModificationTestBase, TestCase):
        for param in optimizer.get_auxiliary_parameters().local:
            workspace.FetchBlob(param)

class TestSmartDecayAdam(OptimizerTestBase, LRModificationTestBase, TestCase):
    def build_optimizer(self, model, **kwargs):
        self._skip_gpu = False
        kwargs['beta1'] = 0.0
        return build_adam(model, base_learning_rate=0.1, use_smart_decay=True, **kwargs)

    def check_optimizer(self, optimizer):
        self.assertTrue(optimizer.get_auxiliary_parameters().shared)
        self.assertTrue(optimizer.get_auxiliary_parameters().local)
        self.assertTrue(workspace.HasBlob("optimizer_iteration"))
        blob_names = workspace.Blobs()
        self.assertTrue(any((bn.endswith('_last_seen') for bn in blob_names)))
        for param in optimizer.get_auxiliary_parameters().shared:
            workspace.FetchBlob(param)
        for param in optimizer.get_auxiliary_parameters().local:
            workspace.FetchBlob(param)

class TestDecayAdagrad(OptimizerTestBase, LRModificationTestBase, TestCase):
    def build_optimizer(self, model, **kwargs):
        self._skip_gpu = True