[LRD] Allow using a dedicated iteration counter for learning rate (#85195)
Summary: Allows manipulating the iteration counter for the learning rate separately (e.g. for learning rate decay or re-warming), without affecting other techniques that rely on the iteration count (such as EMA).
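
For context, a minimal end-to-end sketch (not part of this diff) of how a caller could enable the new flag; the toy model, CPU-only workspace, and unscoped blob names are illustrative assumptions:

```python
# Toy caffe2 model with Adagrad using the new dedicated LR iteration counter.
import numpy as np
from caffe2.python import brew, model_helper, optimizer, utils, workspace

model = model_helper.ModelHelper(name="toy")
data = model.net.AddExternalInput("data")
label = model.net.AddExternalInput("label")
fc = brew.fc(model, data, "fc", dim_in=4, dim_out=2)
softmax, loss = model.net.SoftmaxWithLoss([fc, label], ["softmax", "loss"])
model.AddGradientOperators([loss])

# New flag from this change: the LR schedule gets its own iteration counter.
optimizer.build_adagrad(
    model, base_learning_rate=0.1, use_dedicated_lr_iteration_counter=True
)

workspace.FeedBlob("data", np.random.rand(8, 4).astype(np.float32))
workspace.FeedBlob("label", np.random.randint(0, 2, size=8).astype(np.int32))
workspace.RunNetOnce(model.param_init_net)
workspace.CreateNet(model.net)
workspace.RunNet(model.net.Proto().name, num_iter=3)

# Two counters now exist; by default they advance in lockstep.
print(workspace.FetchBlob(utils.OPTIMIZER_ITERATION_NAME))
print(workspace.FetchBlob(utils.OPTIMIZER_ITERATION_LR_NAME))
```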
Test Plan:
Unit tests:
```
✓ Pass: caffe2/caffe2/python:optimizer_test - testSparse (caffe2.caffe2.python.optimizer_test.TestAdagradWithDedicatedLRIteration) (46.475)
✓ Pass: caffe2/caffe2/python:optimizer_test - test_global_norm_based_gradient_clipping (caffe2.caffe2.python.optimizer_test.TestAdagradWithDedicatedLRIteration) (46.475)
✓ Pass: caffe2/caffe2/python:optimizer_test - test_lr_injection (caffe2.caffe2.python.optimizer_test.TestAdagradWithDedicatedLRIteration) (46.475)
✓ Pass: caffe2/caffe2/python:optimizer_test - main (46.475)
Summary
Pass: 5
Skip: 1
↻ caffe2/caffe2/python:optimizer_test - testGPUDense (caffe2.caffe2.python.optimizer_test.TestAdagradWithDedicatedLRIteration)
ListingSuccess: 1
```
Reviewed By: liangming168
Differential Revision: D38747417
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85195
Approved by: https://github.com/liangming168, https://github.com/eellison
This commit is contained in: parent 784f4ba1ce, commit 755b39ba66
caffe2/python/optimizer.py:

```diff
@@ -39,6 +39,7 @@ class Optimizer(object):
         self._lr_multiplier = None
         self._local_lr_multiplier = None
         self._local_lr_multiplier_on_gpu = False
+        self._use_dedicated_lr_iteration_counter = False
 
     """
     Adds optimization operators to the net for given parameter and its gradient
@@ -86,6 +87,14 @@ class Optimizer(object):
         del attr["_instance_num"]
         return attr
 
+    @property
+    def use_dedicated_lr_iteration_counter(self):
+        return self._use_dedicated_lr_iteration_counter
+
+    @use_dedicated_lr_iteration_counter.setter
+    def use_dedicated_lr_iteration_counter(self, val):
+        self._use_dedicated_lr_iteration_counter = val
+
     def make_unique_blob_name(self, base_str):
         """
         Returns a blob name that will be unique to the current device
```
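For reference, a small sketch (not part of this change) of the new property in use; it is assumed the flag is flipped before the optimizer is applied to a net, and is equivalent to passing the constructor argument added further below:

```python
# Minimal sketch: toggle the new flag on an AdagradOptimizer instance.
from caffe2.python import optimizer

adagrad = optimizer.AdagradOptimizer(alpha=0.1)
adagrad.use_dedicated_lr_iteration_counter = True   # setter added above
print(adagrad.use_dedicated_lr_iteration_counter)   # True
```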
caffe2/python/optimizer.py (continued):

```diff
@@ -115,7 +124,17 @@ class Optimizer(object):
         if learning_rate_blob is None:
             learning_rate_blob = self.make_unique_blob_name("lr")
 
-        iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)
+        if self._use_dedicated_lr_iteration_counter:
+            iteration = utils.BuildUniqueMutexIter(
+                param_init_net,
+                net,
+                iter=utils.OPTIMIZER_ITERATION_LR_NAME,
+                iter_mutex=utils.ITERATION_MUTEX_LR_NAME,
+                iter_val=iter_val,
+            )
+            logger.info(f"Created dedicated learning rate iteration counter: {iteration}")
+        else:
+            iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)
 
         if not net.BlobIsDefined(learning_rate_blob):
             # There is one interesting thing here: since we are minimizing, we are
```
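The dedicated counter exists so it can be manipulated independently of the counter consumed by EMA, SWA, pruning, and similar logic. A hypothetical illustration (this diff does not do this automatically) of re-warming the LR schedule by rewinding only the LR counter; unscoped blob names and a CPU-only workspace are simplifying assumptions:

```python
# Build the stock counter and the new LR-only counter side by side, advance
# both, then rewind only the LR counter (e.g. to re-warm the LR schedule).
import numpy as np
from caffe2.python import core, utils, workspace

init_net = core.Net("counters_init")
step_net = core.Net("counters_step")

utils.BuildUniqueMutexIter(init_net, step_net, iter_val=0)
utils.BuildUniqueMutexIter(
    init_net,
    step_net,
    iter=utils.OPTIMIZER_ITERATION_LR_NAME,
    iter_mutex=utils.ITERATION_MUTEX_LR_NAME,
    iter_val=0,
)

workspace.RunNetOnce(init_net)
for _ in range(5):  # both counters advance in lockstep
    workspace.RunNetOnce(step_net)

# Rewind only the LR counter; the general-purpose counter is untouched.
workspace.FeedBlob(utils.OPTIMIZER_ITERATION_LR_NAME, np.array([0], dtype=np.int64))
print(workspace.FetchBlob(utils.OPTIMIZER_ITERATION_NAME))     # e.g. [5]
print(workspace.FetchBlob(utils.OPTIMIZER_ITERATION_LR_NAME))  # [0]
```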
caffe2/python/optimizer.py (continued):

```diff
@@ -163,6 +182,36 @@ class Optimizer(object):
 
         return lr, iteration
 
+    def build_non_lr_iter(
+        self,
+        net,
+        param_init_net,
+        iter_val=0,
+    ):
+        assert (
+            self._use_dedicated_lr_iteration_counter
+        ), "This method should be only called when dedicated learning rate iteration counter is used."
+
+        iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)
+        logger.info(f"Created iteration counter for non learning rate purposes: {iteration}")
+
+        # We need to create a dummy learning rate operator to enforce that
+        # iteration counter blob being placed in the trainer nodes. Otherwise,
+        # the Automatic Device Placement (ADP) algorithm for Hierachical
+        # Training (HT) will encounter issues to distribute blobs across group
+        # parameter servers. Note that this learning rate operator will not be
+        # used for any other purpose.
+        learning_rate_blob = self.make_unique_blob_name("iter_placement_hint")
+        if not net.BlobIsDefined(learning_rate_blob):
+            net.LearningRate(
+                [iteration],
+                learning_rate_blob,
+                base_lr=1.0,
+                policy="fixed",
+            )
+
+        return iteration
+
     def add_lr_multiplier(self, lr_multiplier):
         """
         Set the global learning rate multiplier. If a multiplier already
```
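The placement hint above is just a fixed-policy LearningRate op whose output is never read; its only job is to keep the iteration blob co-located with the trainer. A standalone sketch of that pattern, under the same CPU-only, unscoped-name assumptions:

```python
# Standalone sketch of the placement-hint pattern used in build_non_lr_iter:
# the LR blob produced here is never consumed by training.
from caffe2.python import core, utils, workspace

init_net = core.Net("hint_init")
train_net = core.Net("hint_train")

iteration = utils.BuildUniqueMutexIter(init_net, train_net, iter_val=0)
train_net.LearningRate([iteration], "iter_placement_hint", base_lr=1.0, policy="fixed")

workspace.RunNetOnce(init_net)
workspace.RunNetOnce(train_net)
print(workspace.FetchBlob("iter_placement_hint"))  # fixed value, never read back
```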
caffe2/python/optimizer.py (continued):

```diff
@@ -582,6 +631,7 @@ class AdagradOptimizer(Optimizer):
         ema_options=None,
         weight_scale=None,
         counter_halflife=-1,
+        use_dedicated_lr_iteration_counter=False,
         **kwargs
     ):
         super(AdagradOptimizer, self).__init__()
@@ -599,6 +649,7 @@ class AdagradOptimizer(Optimizer):
         self.counter_halflife = counter_halflife
         self.init_kwargs = kwargs
         self.weight_scale = weight_scale
+        self.use_dedicated_lr_iteration_counter = use_dedicated_lr_iteration_counter
 
         self._process_pruning_options(pruning_options)
         self._process_swa_options(swa_options)
@@ -727,7 +778,12 @@ class AdagradOptimizer(Optimizer):
             policy=self.policy,
             **(self.init_kwargs)
         )
-        iteration = lr_iteration
+        iteration = (
+            self.build_non_lr_iter(net, param_init_net, iter_val=0)
+            if self._use_dedicated_lr_iteration_counter
+            else lr_iteration
+        )
 
         if self.counter_halflife > 0:
             self._aux_params.shared.append(iteration)
@@ -970,7 +1026,7 @@ class AdagradOptimizer(Optimizer):
             logger.debug("using {} for {}".format(op, str(param)))
 
         if self.prune_delays:
-            input_args += [lr_iteration, last_mask_updated_iter]
+            input_args += [iteration, last_mask_updated_iter]
             output_args += [mask_blob, last_mask_updated_iter]
 
         if weight_decay > 0 and self.counter_halflife == -1:
@@ -1020,7 +1076,7 @@ class AdagradOptimizer(Optimizer):
                 input_args += [mask_blob]
 
             if self.prune_delays:
-                input_args += [lr_iteration, last_mask_updated_iter]
+                input_args += [iteration, last_mask_updated_iter]
                 output_args += [mask_blob, last_mask_updated_iter]
 
             if self.use_mask:
@@ -1063,7 +1119,7 @@ class AdagradOptimizer(Optimizer):
             self._aux_params.local.append(param_swa)
 
             net.SWA(
-                [param, param_swa, lr_iteration],
+                [param, param_swa, iteration],
                 [param, param_swa],
                 avg_start=self.swa_avg_start_it,
                 avg_end=self.swa_avg_end_it,
@@ -1079,7 +1135,7 @@ class AdagradOptimizer(Optimizer):
             self._aux_params.local.append(param_ema)
 
             net.EMA(
-                [param, param_ema, lr_iteration],
+                [param, param_ema, iteration],
                 [param, param_ema],
                 ema_start=self.ema_start,
                 ema_end=self.ema_end,
@@ -1089,7 +1145,7 @@ class AdagradOptimizer(Optimizer):
 
         if self.weight_scale:
             net.WeightScale(
-                [param, lr_iteration],
+                [param, iteration],
                 [param],
                 stepsize=self.weight_scale.stepsize,
                 upper_bound_iter=self.weight_scale.upper_bound_iter,
@@ -1097,7 +1153,7 @@ class AdagradOptimizer(Optimizer):
             )
             if self.weight_scale.to_aux:
                 net.WeightScale(
-                    [param_squared_sum, lr_iteration],
+                    [param_squared_sum, iteration],
                     [param_squared_sum],
                     stepsize=self.weight_scale.stepsize,
                     upper_bound_iter=self.weight_scale.upper_bound_iter,
```
caffe2/python/optimizer_test.py:

```diff
@@ -11,7 +11,7 @@ from caffe2.python.optimizer_context import UseOptimizer
 from caffe2.python.optimizer_test_util import (
     OptimizerTestBase, LRModificationTestBase
 )
-from caffe2.python import core, workspace
+from caffe2.python import core, utils, workspace
 from caffe2.python.test_util import TestCase
 import numpy as np
 from numpy.testing import assert_allclose, assert_equal
@@ -137,6 +137,26 @@ class TestAdagrad(OptimizerTestBase, LRModificationTestBase, TestCase):
             workspace.FetchBlob(param)
 
 
+class TestAdagradWithDedicatedLRIteration(OptimizerTestBase, LRModificationTestBase, TestCase):
+    def build_optimizer(self, model, **kwargs):
+        self._skip_gpu = False
+        return build_adagrad(model, base_learning_rate=1.0, lars=0.5, use_dedicated_lr_iteration_counter=True, **kwargs)
+
+    def check_optimizer(self, optimizer):
+        self.assertFalse(optimizer.get_auxiliary_parameters().shared)
+        self.assertTrue(optimizer.get_auxiliary_parameters().local)
+        for param in optimizer.get_auxiliary_parameters().local:
+            workspace.FetchBlob(param)
+
+        # check iteration counters have the same value by default
+        non_lr_iter = workspace.FetchBlob(utils.OPTIMIZER_ITERATION_NAME)
+        lr_iter = workspace.FetchBlob(utils.OPTIMIZER_ITERATION_LR_NAME)
+        self.assertEqual(non_lr_iter, lr_iter)
+
+    def testGPUDense(self):
+        raise unittest.SkipTest("GPU support is not validated")
+
+
 class TestRowWiseAdagrad(OptimizerTestBase, TestCase):
     def build_optimizer(self, model, **kwargs):
         self._skip_gpu = True
```
caffe2/python/utils.py:

```diff
@@ -18,7 +18,9 @@ import numpy as np
 from six import integer_types, binary_type, text_type, string_types
 
 OPTIMIZER_ITERATION_NAME = "optimizer_iteration"
+OPTIMIZER_ITERATION_LR_NAME = "optimizer_iteration_lr"
 ITERATION_MUTEX_NAME = "iteration_mutex"
+ITERATION_MUTEX_LR_NAME = "iteration_mutex_lr"
 
 
 def OpAlmostEqual(op_a, op_b, ignore_fields=None):
```