[LRD] Allow using a dedicated iteration counter for learning rate (#85195)

Summary: Allow the iteration counter used for the learning rate to be manipulated separately (e.g., for learning rate decay or learning rate re-warming), without affecting other techniques that rely on the iteration count (such as EMA).
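
As a rough usage sketch (the model construction below is a placeholder and assumes parameters with gradients already exist; only `build_adagrad`, the new `use_dedicated_lr_iteration_counter` flag, and the `optimizer_iteration` / `optimizer_iteration_lr` blob names come from this change):
```
from caffe2.python import model_helper, optimizer, utils, workspace
import numpy as np

# Placeholder model; any model with parameters and gradients would do.
model = model_helper.ModelHelper(name="example")
# ... add ops, parameters and gradients here ...

# Build Adagrad with the new flag, so the learning rate is driven by its own
# counter ("optimizer_iteration_lr") instead of the shared counter
# ("optimizer_iteration") used by EMA and other iteration-based techniques.
optimizer.build_adagrad(
    model,
    base_learning_rate=1.0,
    use_dedicated_lr_iteration_counter=True,
)

workspace.RunNetOnce(model.param_init_net)
# ... create and run model.net for some iterations ...

# The LR counter can now be overwritten (e.g. reset to re-warm the learning
# rate) without touching the counter read by the other techniques.
workspace.FeedBlob(
    utils.OPTIMIZER_ITERATION_LR_NAME,
    np.array([0], dtype=np.int64),
)
non_lr_iter = workspace.FetchBlob(utils.OPTIMIZER_ITERATION_NAME)
lr_iter = workspace.FetchBlob(utils.OPTIMIZER_ITERATION_LR_NAME)
```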

Test Plan:
Unit tests:
```
    ✓ Pass: caffe2/caffe2/python:optimizer_test - testSparse (caffe2.caffe2.python.optimizer_test.TestAdagradWithDedicatedLRIteration) (46.475)
    ✓ Pass: caffe2/caffe2/python:optimizer_test - test_global_norm_based_gradient_clipping (caffe2.caffe2.python.optimizer_test.TestAdagradWithDedicatedLRIteration) (46.475)
    ✓ Pass: caffe2/caffe2/python:optimizer_test - test_lr_injection (caffe2.caffe2.python.optimizer_test.TestAdagradWithDedicatedLRIteration) (46.475)
    ✓ Pass: caffe2/caffe2/python:optimizer_test - main (46.475)
Summary
  Pass: 5
  Skip: 1
    ↻ caffe2/caffe2/python:optimizer_test - testGPUDense (caffe2.caffe2.python.optimizer_test.TestAdagradWithDedicatedLRIteration)
  ListingSuccess: 1
```

Reviewed By: liangming168

Differential Revision: D38747417

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85195
Approved by: https://github.com/liangming168, https://github.com/eellison
Author: Wenguang Mao, 2022-09-27 00:56:57 +00:00 (committed by PyTorch MergeBot)
Commit: 755b39ba66 (parent: 784f4ba1ce)
3 changed files with 87 additions and 9 deletions

caffe2/python/optimizer.py

@@ -39,6 +39,7 @@ class Optimizer(object):
self._lr_multiplier = None
self._local_lr_multiplier = None
self._local_lr_multiplier_on_gpu = False
self._use_dedicated_lr_iteration_counter = False
"""
Adds optimization operators to the net for given parameter and its gradient
@@ -86,6 +87,14 @@ class Optimizer(object):
del attr["_instance_num"]
return attr
@property
def use_dedicated_lr_iteration_counter(self):
return self._use_dedicated_lr_iteration_counter
@use_dedicated_lr_iteration_counter.setter
def use_dedicated_lr_iteration_counter(self, val):
self._use_dedicated_lr_iteration_counter = val
def make_unique_blob_name(self, base_str):
"""
Returns a blob name that will be unique to the current device
@@ -115,7 +124,17 @@ class Optimizer(object):
if learning_rate_blob is None:
learning_rate_blob = self.make_unique_blob_name("lr")
iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)
if self._use_dedicated_lr_iteration_counter:
iteration = utils.BuildUniqueMutexIter(
param_init_net,
net,
iter=utils.OPTIMIZER_ITERATION_LR_NAME,
iter_mutex=utils.ITERATION_MUTEX_LR_NAME,
iter_val=iter_val,
)
logger.info(f"Created dedicated learning rate iteration counter: {iteration}")
else:
iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)
if not net.BlobIsDefined(learning_rate_blob):
# There is one interesting thing here: since we are minimizing, we are
@@ -163,6 +182,36 @@ class Optimizer(object):
return lr, iteration
def build_non_lr_iter(
self,
net,
param_init_net,
iter_val=0,
):
assert (
self._use_dedicated_lr_iteration_counter
), "This method should be only called when dedicated learning rate iteration counter is used."
iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)
logger.info(f"Created iteration counter for non learning rate purposes: {iteration}")
# We need to create a dummy learning rate operator to enforce that the
# iteration counter blob is placed on the trainer nodes. Otherwise, the
# Automatic Device Placement (ADP) algorithm for Hierarchical Training (HT)
# will have trouble distributing blobs across group parameter servers. Note
# that this learning rate operator will not be used for any other purpose.
learning_rate_blob = self.make_unique_blob_name("iter_placement_hint")
if not net.BlobIsDefined(learning_rate_blob):
net.LearningRate(
[iteration],
learning_rate_blob,
base_lr=1.0,
policy="fixed",
)
return iteration
def add_lr_multiplier(self, lr_multiplier):
"""
Set the global learning rate multiplier. If a multiplier already
@@ -582,6 +631,7 @@ class AdagradOptimizer(Optimizer):
ema_options=None,
weight_scale=None,
counter_halflife=-1,
use_dedicated_lr_iteration_counter=False,
**kwargs
):
super(AdagradOptimizer, self).__init__()
@@ -599,6 +649,7 @@ class AdagradOptimizer(Optimizer):
self.counter_halflife = counter_halflife
self.init_kwargs = kwargs
self.weight_scale = weight_scale
self.use_dedicated_lr_iteration_counter = use_dedicated_lr_iteration_counter
self._process_pruning_options(pruning_options)
self._process_swa_options(swa_options)
@@ -727,7 +778,12 @@ class AdagradOptimizer(Optimizer):
policy=self.policy,
**(self.init_kwargs)
)
iteration = lr_iteration
iteration = (
self.build_non_lr_iter(net, param_init_net, iter_val=0)
if self._use_dedicated_lr_iteration_counter
else lr_iteration
)
if self.counter_halflife > 0:
self._aux_params.shared.append(iteration)
@@ -970,7 +1026,7 @@ class AdagradOptimizer(Optimizer):
logger.debug("using {} for {}".format(op, str(param)))
if self.prune_delays:
input_args += [lr_iteration, last_mask_updated_iter]
input_args += [iteration, last_mask_updated_iter]
output_args += [mask_blob, last_mask_updated_iter]
if weight_decay > 0 and self.counter_halflife == -1:
@@ -1020,7 +1076,7 @@ class AdagradOptimizer(Optimizer):
input_args += [mask_blob]
if self.prune_delays:
input_args += [lr_iteration, last_mask_updated_iter]
input_args += [iteration, last_mask_updated_iter]
output_args += [mask_blob, last_mask_updated_iter]
if self.use_mask:
@@ -1063,7 +1119,7 @@ class AdagradOptimizer(Optimizer):
self._aux_params.local.append(param_swa)
net.SWA(
[param, param_swa, lr_iteration],
[param, param_swa, iteration],
[param, param_swa],
avg_start=self.swa_avg_start_it,
avg_end=self.swa_avg_end_it,
@@ -1079,7 +1135,7 @@ class AdagradOptimizer(Optimizer):
self._aux_params.local.append(param_ema)
net.EMA(
[param, param_ema, lr_iteration],
[param, param_ema, iteration],
[param, param_ema],
ema_start=self.ema_start,
ema_end=self.ema_end,
@@ -1089,7 +1145,7 @@ class AdagradOptimizer(Optimizer):
if self.weight_scale:
net.WeightScale(
[param, lr_iteration],
[param, iteration],
[param],
stepsize=self.weight_scale.stepsize,
upper_bound_iter=self.weight_scale.upper_bound_iter,
@@ -1097,7 +1153,7 @@ class AdagradOptimizer(Optimizer):
)
if self.weight_scale.to_aux:
net.WeightScale(
[param_squared_sum, lr_iteration],
[param_squared_sum, iteration],
[param_squared_sum],
stepsize=self.weight_scale.stepsize,
upper_bound_iter=self.weight_scale.upper_bound_iter,

caffe2/python/optimizer_test.py

@@ -11,7 +11,7 @@ from caffe2.python.optimizer_context import UseOptimizer
from caffe2.python.optimizer_test_util import (
OptimizerTestBase, LRModificationTestBase
)
from caffe2.python import core, workspace
from caffe2.python import core, utils, workspace
from caffe2.python.test_util import TestCase
import numpy as np
from numpy.testing import assert_allclose, assert_equal
@@ -137,6 +137,26 @@ class TestAdagrad(OptimizerTestBase, LRModificationTestBase, TestCase):
workspace.FetchBlob(param)
class TestAdagradWithDedicatedLRIteration(OptimizerTestBase, LRModificationTestBase, TestCase):
def build_optimizer(self, model, **kwargs):
self._skip_gpu = False
return build_adagrad(model, base_learning_rate=1.0, lars=0.5, use_dedicated_lr_iteration_counter=True, **kwargs)
def check_optimizer(self, optimizer):
self.assertFalse(optimizer.get_auxiliary_parameters().shared)
self.assertTrue(optimizer.get_auxiliary_parameters().local)
for param in optimizer.get_auxiliary_parameters().local:
workspace.FetchBlob(param)
# check iteration counters have the same value by default
non_lr_iter = workspace.FetchBlob(utils.OPTIMIZER_ITERATION_NAME)
lr_iter = workspace.FetchBlob(utils.OPTIMIZER_ITERATION_LR_NAME)
self.assertEqual(non_lr_iter, lr_iter)
def testGPUDense(self):
raise unittest.SkipTest("GPU support is not validated")
class TestRowWiseAdagrad(OptimizerTestBase, TestCase):
def build_optimizer(self, model, **kwargs):
self._skip_gpu = True

caffe2/python/utils.py

@@ -18,7 +18,9 @@ import numpy as np
from six import integer_types, binary_type, text_type, string_types
OPTIMIZER_ITERATION_NAME = "optimizer_iteration"
OPTIMIZER_ITERATION_LR_NAME = "optimizer_iteration_lr"
ITERATION_MUTEX_NAME = "iteration_mutex"
ITERATION_MUTEX_LR_NAME = "iteration_mutex_lr"
def OpAlmostEqual(op_a, op_b, ignore_fields=None):