diff --git a/caffe2/python/optimizer.py b/caffe2/python/optimizer.py
index ae5d9b0b55a..5ff90461e2e 100644
--- a/caffe2/python/optimizer.py
+++ b/caffe2/python/optimizer.py
@@ -520,7 +520,9 @@ class AdagradOptimizer(Optimizer):
                  sparse_dedup_aggregator=None, rowWise=False, engine='',
                  lars=None, output_effective_lr=False,
                  output_effective_lr_and_update=False,
-                 pruning_options=None, swa_options=None, weight_scale=None, **kwargs):
+                 pruning_options=None, swa_options=None, weight_scale=None,
+                 counter_halflife=-1,
+                 **kwargs):
         for k, v in locals().items():
             logger.info(
                 'AdagradOptimizer: input arguments: {}: {}'.format(k, v))
@@ -536,6 +538,7 @@ class AdagradOptimizer(Optimizer):
         self.lars = lars
         self.output_effective_lr = output_effective_lr
         self.output_effective_lr_and_update = output_effective_lr_and_update
+        self.counter_halflife = counter_halflife
         self.init_kwargs = kwargs
         self.weight_scale = weight_scale
 
@@ -628,6 +631,9 @@ class AdagradOptimizer(Optimizer):
             policy=self.policy,
             **(self.init_kwargs)
         )
+        iteration = lr_iteration
+        if self.counter_halflife > 0:
+            self._aux_params.shared.append(iteration)
 
         if self.rowWise:
             logger.info(
@@ -730,6 +736,44 @@ class AdagradOptimizer(Optimizer):
                 "a delay iter needs to be provided")
 
         self._aux_params.local.append(param_squared_sum)
+        if self.counter_halflife > 0:
+            shapes, types = workspace.InferShapesAndTypes([param_init_net])
+            if str(param) not in shapes:
+                shape = param_init_net.Shape(param, str(param) + "_shape")
+                num_rows = param_init_net.Slice(
+                    [shape],
+                    str(shape) + "_numrows",
+                    starts=[0], ends=[1]
+                )
+                update_counter = param_init_net.ConstantFill(
+                    num_rows,
+                    str(param) + "_update_counter",
+                    input_as_shape=1,
+                    value=0.0,
+                )
+                prev_update_iter = param_init_net.ConstantFill(
+                    num_rows,
+                    str(param) + "_prev_update_iter",
+                    input_as_shape=1,
+                    value=0,
+                    dtype=core.DataType.INT64,
+                )
+            else:
+                update_counter = param_init_net.ConstantFill(
+                    [],
+                    str(param) + "_update_counter",
+                    shape=[shapes[str(param)][0]],
+                    value=0.0,
+                )
+                prev_update_iter = param_init_net.ConstantFill(
+                    [],
+                    str(param) + "_prev_update_iter",
+                    shape=[shapes[str(param)][0]],
+                    value=0,
+                    dtype=core.DataType.INT64,
+                )
+            self._aux_params.local.append(update_counter)
+            self._aux_params.local.append(prev_update_iter)
 
         if self.rowWise:
             assert isinstance(grad, core.GradientSlice),\
@@ -801,6 +845,12 @@ class AdagradOptimizer(Optimizer):
                 epsilon=self.epsilon,
                 engine=self.engine,
             )
+            if self.counter_halflife > 0:
+                net.RowWiseCounter(
+                    [prev_update_iter, update_counter, grad.indices, iteration],
+                    [prev_update_iter, update_counter],
+                    counter_halflife=self.counter_halflife,
+                )
         else:
             input_args = [param, param_squared_sum, grad, lr]
             output_args = [param, param_squared_sum]
diff --git a/caffe2/python/optimizer_test.py b/caffe2/python/optimizer_test.py
index b2a5e58ed70..a45571f1968 100644
--- a/caffe2/python/optimizer_test.py
+++ b/caffe2/python/optimizer_test.py
@@ -156,6 +156,36 @@ class TestRowWiseAdagrad(OptimizerTestBase, TestCase):
     def testGPUDense(self):
         raise unittest.SkipTest("no dense support")
 
+class TestRowWiseAdagradWithCounter(OptimizerTestBase, TestCase):
+    def build_optimizer(self, model, **kwargs):
+        self._skip_gpu = True
+        return build_adagrad(
+            model,
+            base_learning_rate=1.0,
+            lars=0.5,
+            rowWise=True,
+            counter_halflife=5,
+            **kwargs
+        )
+
+    def check_optimizer(self, optimizer):
+        self.assertTrue(optimizer.get_auxiliary_parameters().shared)
+        self.assertTrue(optimizer.get_auxiliary_parameters().local)
+        self.assertTrue(workspace.HasBlob("optimizer_iteration"))
+        iteration_tensor = workspace.FetchBlob("optimizer_iteration")
+        np.testing.assert_allclose(np.array([2000]),
+                                   iteration_tensor,
+                                   atol=1e-5)
+        for param in optimizer.get_auxiliary_parameters().shared:
+            workspace.FetchBlob(param)
+        for param in optimizer.get_auxiliary_parameters().local:
+            workspace.FetchBlob(param)
+
+    def testDense(self):
+        raise unittest.SkipTest("no dense support")
+
+    def testGPUDense(self):
+        raise unittest.SkipTest("no dense support")
 
 class TestWngrad(OptimizerTestBase, LRModificationTestBase, TestCase):
     def build_optimizer(self, model, **kwargs):
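
Note on the counter semantics: the patch only wires the RowWiseCounter operator into the row-wise Adagrad path; the kernel itself lives in the C++ op and is not shown here. A minimal NumPy sketch of the behavior the counter_halflife argument and the op's input/output signature ([prev_update_iter, update_counter, grad.indices, iteration] -> [prev_update_iter, update_counter]) suggest is given below. The function name and the exact exponential-decay form are assumptions for illustration, not taken from this diff.

import numpy as np

def row_wise_counter_sketch(prev_update_iter, update_counter, indices,
                            curr_iter, counter_halflife):
    """Assumed semantics of RowWiseCounter: for each row touched by the
    sparse gradient, decay its running update count so that it halves
    after counter_halflife iterations without an update, then add 1 for
    the current update. Untouched rows keep their stale count (and stale
    prev_update_iter) until their next visit."""
    decay = np.log(2.0) / counter_halflife
    for idx in indices:
        elapsed = curr_iter - prev_update_iter[idx]
        update_counter[idx] = 1.0 + np.exp(-decay * elapsed) * update_counter[idx]
        prev_update_iter[idx] = curr_iter
    return prev_update_iter, update_counter

This reading also matches the counter_halflife=-1 default: every new code path in the optimizer is gated on self.counter_halflife > 0, so the counter blobs and the RowWiseCounter call are skipped entirely unless a positive half-life is requested.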