[pruning] add rowwise counter to sparse adagrad

Summary: Use the newly added RowWiseCounter op in sparse Adagrad to track per-row update counts.

Reviewed By: chocjy, ellie-wen

Differential Revision: D19221100

fbshipit-source-id: d939d83e3b5b3179f57194be2e8864d0fbbee2c1
Rui Liu 2020-06-30 14:34:09 -07:00 committed by Facebook GitHub Bot
parent 40e79bb1d3
commit 9d8dc0318b
2 changed files with 81 additions and 1 deletion
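
For orientation, a minimal sketch of how the new option would be enabled from user code; the model setup here is illustrative, and counter_halflife=5 mirrors the unit test added below (counter_halflife defaults to -1, which leaves the counter disabled):

    from caffe2.python import model_helper
    from caffe2.python.optimizer import build_adagrad

    # Illustrative model; in practice it would produce sparse (GradientSlice)
    # gradients for embedding-style parameters.
    model = model_helper.ModelHelper(name="sparse_model")

    build_adagrad(
        model,
        base_learning_rate=1.0,
        rowWise=True,          # rowwise sparse Adagrad
        counter_halflife=5,    # > 0 enables the per-row update counter
    )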

caffe2/python/optimizer.py

@@ -520,7 +520,9 @@ class AdagradOptimizer(Optimizer):
                  sparse_dedup_aggregator=None, rowWise=False, engine='',
                  lars=None, output_effective_lr=False,
                  output_effective_lr_and_update=False,
-                 pruning_options=None, swa_options=None, weight_scale=None, **kwargs):
+                 pruning_options=None, swa_options=None, weight_scale=None,
+                 counter_halflife=-1,
+                 **kwargs):
         for k, v in locals().items():
             logger.info('AdagradOptimizer: input arguments: {}: {}'.format(k, v))
@@ -536,6 +538,7 @@ class AdagradOptimizer(Optimizer):
         self.lars = lars
         self.output_effective_lr = output_effective_lr
         self.output_effective_lr_and_update = output_effective_lr_and_update
+        self.counter_halflife = counter_halflife
         self.init_kwargs = kwargs
         self.weight_scale = weight_scale
@@ -628,6 +631,9 @@ class AdagradOptimizer(Optimizer):
             policy=self.policy,
             **(self.init_kwargs)
         )
+        iteration = lr_iteration
+        if self.counter_halflife > 0:
+            self._aux_params.shared.append(iteration)
 
         if self.rowWise:
             logger.info(
@@ -730,6 +736,44 @@ class AdagradOptimizer(Optimizer):
                 "a delay iter needs to be provided")
 
         self._aux_params.local.append(param_squared_sum)
+        if self.counter_halflife > 0:
+            shapes, types = workspace.InferShapesAndTypes([param_init_net])
+            if str(param) not in shapes:
+                shape = param_init_net.Shape(param, str(param) + "_shape")
+                num_rows = param_init_net.Slice(
+                    [shape],
+                    str(shape) + "_numrows",
+                    starts=[0], ends=[1]
+                )
+                update_counter = param_init_net.ConstantFill(
+                    num_rows,
+                    str(param) + "_update_counter",
+                    input_as_shape=1,
+                    value=0.0,
+                )
+                prev_update_iter = param_init_net.ConstantFill(
+                    num_rows,
+                    str(param) + "_prev_update_iter",
+                    input_as_shape=1,
+                    value=0,
+                    dtype=core.DataType.INT64,
+                )
+            else:
+                update_counter = param_init_net.ConstantFill(
+                    [],
+                    str(param) + "_update_counter",
+                    shape=[shapes[str(param)][0]],
+                    value=0.0,
+                )
+                prev_update_iter = param_init_net.ConstantFill(
+                    [],
+                    str(param) + "_prev_update_iter",
+                    shape=[shapes[str(param)][0]],
+                    value=0,
+                    dtype=core.DataType.INT64,
+                )
+            self._aux_params.local.append(update_counter)
+            self._aux_params.local.append(prev_update_iter)
 
         if self.rowWise:
             assert isinstance(grad, core.GradientSlice),\
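
The two new blobs per parameter act as a rowwise update counter: str(param) + "_update_counter" holds one float count per row, and str(param) + "_prev_update_iter" records the iteration at which each row was last touched. Assuming the RowWiseCounter op applies an exponential decay with the given halflife (so activity from counter_halflife iterations ago is weighted by one half), the per-step bookkeeping would look roughly like this pure-NumPy sketch, an illustration of the intent rather than the op's exact arithmetic:

    import numpy as np

    def rowwise_counter_step(prev_update_iter, update_counter, indices,
                             cur_iter, counter_halflife):
        # Decay each touched row's count by 2 ** (-gap / halflife), where gap
        # is the number of iterations since that row was last updated, then
        # count the current update and remember the current iteration.
        gap = cur_iter - prev_update_iter[indices]
        update_counter[indices] = 1.0 + update_counter[indices] * 0.5 ** (gap / counter_halflife)
        prev_update_iter[indices] = cur_iter
        return prev_update_iter, update_counter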
@@ -801,6 +845,12 @@ class AdagradOptimizer(Optimizer):
                 epsilon=self.epsilon,
                 engine=self.engine,
             )
+            if self.counter_halflife > 0:
+                net.RowWiseCounter(
+                    [prev_update_iter, update_counter, grad.indices, iteration],
+                    [prev_update_iter, update_counter],
+                    counter_halflife=self.counter_halflife,
+                )
         else:
             input_args = [param, param_squared_sum, grad, lr]
             output_args = [param, param_squared_sum]
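
After training, the counters can be read back from the workspace like any other auxiliary blob. The parameter name "embedding_w" below is purely illustrative; the actual blob names follow the str(param) + "_update_counter" / "_prev_update_iter" pattern used above:

    from caffe2.python import workspace

    # One float count per row of the parameter, and the int64 iteration at
    # which each row was last updated.
    update_counter = workspace.FetchBlob("embedding_w_update_counter")
    prev_update_iter = workspace.FetchBlob("embedding_w_prev_update_iter")
    print(update_counter.shape, prev_update_iter.dtype)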

caffe2/python/optimizer_test.py

@@ -156,6 +156,36 @@ class TestRowWiseAdagrad(OptimizerTestBase, TestCase):
     def testGPUDense(self):
         raise unittest.SkipTest("no dense support")
 
+
+class TestRowWiseAdagradWithCounter(OptimizerTestBase, TestCase):
+    def build_optimizer(self, model, **kwargs):
+        self._skip_gpu = True
+        return build_adagrad(
+            model,
+            base_learning_rate=1.0,
+            lars=0.5,
+            rowWise=True,
+            counter_halflife=5,
+            **kwargs
+        )
+
+    def check_optimizer(self, optimizer):
+        self.assertTrue(optimizer.get_auxiliary_parameters().shared)
+        self.assertTrue(optimizer.get_auxiliary_parameters().local)
+        self.assertTrue(workspace.HasBlob("optimizer_iteration"))
+        iteration_tensor = workspace.FetchBlob("optimizer_iteration")
+        np.testing.assert_allclose(np.array([2000]),
+                                   iteration_tensor,
+                                   atol=1e-5)
+        for param in optimizer.get_auxiliary_parameters().shared:
+            workspace.FetchBlob(param)
+        for param in optimizer.get_auxiliary_parameters().local:
+            workspace.FetchBlob(param)
+
+    def testDense(self):
+        raise unittest.SkipTest("no dense support")
+
+    def testGPUDense(self):
+        raise unittest.SkipTest("no dense support")
 
 
 class TestWngrad(OptimizerTestBase, LRModificationTestBase, TestCase):
     def build_optimizer(self, model, **kwargs):