[pruning] add rowwise counter to sparse adagrad
Summary: Use the newly added counter op in sparse adagrad.

Reviewed By: chocjy, ellie-wen

Differential Revision: D19221100

fbshipit-source-id: d939d83e3b5b3179f57194be2e8864d0fbbee2c1
This commit is contained in:
parent
40e79bb1d3
commit
9d8dc0318b
@@ -520,7 +520,9 @@ class AdagradOptimizer(Optimizer):
                  sparse_dedup_aggregator=None, rowWise=False, engine='',
                  lars=None, output_effective_lr=False,
                  output_effective_lr_and_update=False,
-                 pruning_options=None, swa_options=None, weight_scale=None, **kwargs):
+                 pruning_options=None, swa_options=None, weight_scale=None,
+                 counter_halflife=-1,
+                 **kwargs):
         for k, v in locals().items():
             logger.info('AdagradOptimizer: input arguments: {}: {}'.format(k, v))
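The new counter_halflife argument defaults to -1; the "if self.counter_halflife > 0" checks added later in this diff suggest that non-positive values leave the counter disabled. A minimal usage sketch, assuming the Caffe2 build_adagrad helper forwards extra keyword arguments to AdagradOptimizer the way the test at the end of this diff does (the model setup here is hypothetical):

from caffe2.python import model_helper
from caffe2.python.optimizer import build_adagrad

# Hypothetical model; in practice it would already contain parameters,
# sparse lookups, and gradients before the optimizer is attached.
model = model_helper.ModelHelper(name="sparse_example")

build_adagrad(
    model,
    base_learning_rate=0.1,
    rowWise=True,          # row-wise Adagrad for sparse embedding tables
    counter_halflife=10,   # enable the per-row update counter
)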
@@ -536,6 +538,7 @@ class AdagradOptimizer(Optimizer):
         self.lars = lars
         self.output_effective_lr = output_effective_lr
         self.output_effective_lr_and_update = output_effective_lr_and_update
+        self.counter_halflife = counter_halflife
         self.init_kwargs = kwargs
         self.weight_scale = weight_scale
@@ -628,6 +631,9 @@ class AdagradOptimizer(Optimizer):
             policy=self.policy,
             **(self.init_kwargs)
         )
+        iteration = lr_iteration
+        if self.counter_halflife > 0:
+            self._aux_params.shared.append(iteration)

         if self.rowWise:
             logger.info(
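Registering the iteration blob as a shared auxiliary parameter makes the global step visible to checkpointing and to the counter op added below. A small sketch of how that surfaces at runtime (the blob name "optimizer_iteration" is taken from the test at the end of this diff):

from caffe2.python import workspace

# After the training net has run, the shared iteration blob can be read back
# from the workspace, alongside the per-parameter counter blobs.
iteration = workspace.FetchBlob("optimizer_iteration")
print(iteration)  # e.g. array([2000]) after 2000 iterations, as the test below asserts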
@@ -730,6 +736,44 @@ class AdagradOptimizer(Optimizer):
                 "a delay iter needs to be provided")

         self._aux_params.local.append(param_squared_sum)
+        if self.counter_halflife > 0:
+            shapes, types = workspace.InferShapesAndTypes([param_init_net])
+            if str(param) not in shapes:
+                shape = param_init_net.Shape(param, str(param) + "_shape")
+                num_rows = param_init_net.Slice(
+                    [shape],
+                    str(shape) + "_numrows",
+                    starts=[0], ends=[1]
+                )
+                update_counter = param_init_net.ConstantFill(
+                    num_rows,
+                    str(param) + "_update_counter",
+                    input_as_shape=1,
+                    value=0.0,
+                )
+                prev_update_iter = param_init_net.ConstantFill(
+                    num_rows,
+                    str(param) + "_prev_update_iter",
+                    input_as_shape=1,
+                    value=0,
+                    dtype=core.DataType.INT64,
+                )
+            else:
+                update_counter = param_init_net.ConstantFill(
+                    [],
+                    str(param) + "_update_counter",
+                    shape=[shapes[str(param)][0]],
+                    value=0.0,
+                )
+                prev_update_iter = param_init_net.ConstantFill(
+                    [],
+                    str(param) + "_prev_update_iter",
+                    shape=[shapes[str(param)][0]],
+                    value=0,
+                    dtype=core.DataType.INT64,
+                )
+            self._aux_params.local.append(update_counter)
+            self._aux_params.local.append(prev_update_iter)

         if self.rowWise:
             assert isinstance(grad, core.GradientSlice),\
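The hunk above creates two per-row auxiliary blobs for each sparse parameter: a float update counter and an int64 record of the iteration at which each row was last touched, sized either from inferred shapes or from a runtime Shape/Slice when shape inference cannot see the parameter. A rough NumPy sketch of what those blobs hold (the row count is hypothetical):

import numpy as np

num_rows = 1000  # hypothetical number of rows in the sparse (embedding) parameter

# One float counter per row, decayed over time (see the RowWiseCounter call below).
update_counter = np.zeros(num_rows, dtype=np.float32)

# Iteration at which each row was last updated; int64 to match the iteration blob.
prev_update_iter = np.zeros(num_rows, dtype=np.int64)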
@@ -801,6 +845,12 @@ class AdagradOptimizer(Optimizer):
                 epsilon=self.epsilon,
                 engine=self.engine,
             )
+            if self.counter_halflife > 0:
+                net.RowWiseCounter(
+                    [prev_update_iter, update_counter, grad.indices, iteration],
+                    [prev_update_iter, update_counter],
+                    counter_halflife=self.counter_halflife,
+                )
         else:
             input_args = [param, param_squared_sum, grad, lr]
             output_args = [param, param_squared_sum]
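RowWiseCounter is the op introduced by the referenced diff (D19221100); it reads the rows touched by the sparse gradient plus the current iteration and rewrites the two counter blobs in place. The exact decay rule is defined by the op itself, so the NumPy sketch below is only a plausible reading of a halflife-based counter, not the op's definition:

import numpy as np

def rowwise_counter_step(prev_update_iter, update_counter, indices, curr_iter, counter_halflife):
    # Assumed semantics: for each row touched by this sparse update, decay the old
    # count by 2 ** (-elapsed / halflife), then count the current update.
    for idx in np.unique(indices):
        elapsed = curr_iter - prev_update_iter[idx]
        update_counter[idx] = 1.0 + update_counter[idx] * 0.5 ** (elapsed / counter_halflife)
        prev_update_iter[idx] = curr_iter
    return prev_update_iter, update_counter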
@@ -156,6 +156,36 @@ class TestRowWiseAdagrad(OptimizerTestBase, TestCase):
     def testGPUDense(self):
         raise unittest.SkipTest("no dense support")

+class TestRowWiseAdagradWithCounter(OptimizerTestBase, TestCase):
+    def build_optimizer(self, model, **kwargs):
+        self._skip_gpu = True
+        return build_adagrad(
+            model,
+            base_learning_rate=1.0,
+            lars=0.5,
+            rowWise=True,
+            counter_halflife=5,
+            **kwargs
+        )
+
+    def check_optimizer(self, optimizer):
+        self.assertTrue(optimizer.get_auxiliary_parameters().shared)
+        self.assertTrue(optimizer.get_auxiliary_parameters().local)
+        self.assertTrue(workspace.HasBlob("optimizer_iteration"))
+        iteration_tensor = workspace.FetchBlob("optimizer_iteration")
+        np.testing.assert_allclose(np.array([2000]),
+                                   iteration_tensor,
+                                   atol=1e-5)
+        for param in optimizer.get_auxiliary_parameters().shared:
+            workspace.FetchBlob(param)
+        for param in optimizer.get_auxiliary_parameters().local:
+            workspace.FetchBlob(param)
+
+    def testDense(self):
+        raise unittest.SkipTest("no dense support")
+
+    def testGPUDense(self):
+        raise unittest.SkipTest("no dense support")

 class TestWngrad(OptimizerTestBase, LRModificationTestBase, TestCase):
     def build_optimizer(self, model, **kwargs):
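The new test reuses OptimizerTestBase, so it only needs to build the optimizer with counter_halflife set and then confirm that both the shared (iteration) and local (per-row counter) auxiliary blobs exist and are fetchable. Assuming the class lives in caffe2/python/optimizer_test.py, as the surrounding test classes suggest, it could be run in isolation with something like:

# Hypothetical invocation; the file path is inferred, not stated in this commit page.
#
#     python -m pytest caffe2/python/optimizer_test.py -k TestRowWiseAdagradWithCounter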