[pruning] add rowwise counter to sparse adagrad

Summary: Use the newly added counter op in sparse adagrad

Reviewed By: chocjy, ellie-wen

Differential Revision: D19221100

fbshipit-source-id: d939d83e3b5b3179f57194be2e8864d0fbbee2c1

commit 9d8dc0318b
parent 40e79bb1d3
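The RowWiseCounter op wired in below maintains, for each row of a sparse parameter, a running update count that decays while the row goes untouched. A minimal NumPy sketch of plausible semantics; the 2 ** (-idle / halflife) decay rule is an assumption suggested by the counter_halflife name, not something this diff specifies:

import numpy as np

def rowwise_counter(prev_update_iter, update_counter, indices, cur_iter, halflife):
    # Assumed update rule: decay each touched row's count by how long the
    # row has been idle, then record this update and its iteration.
    for i in indices:
        idle = cur_iter - prev_update_iter[i]
        update_counter[i] = 1.0 + update_counter[i] * 2.0 ** (-idle / halflife)
        prev_update_iter[i] = cur_iter

update_counter = np.zeros(4, dtype=np.float32)
prev_update_iter = np.zeros(4, dtype=np.int64)
rowwise_counter(prev_update_iter, update_counter, [0, 2], cur_iter=10, halflife=5)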
--- a/caffe2/python/optimizer.py
+++ b/caffe2/python/optimizer.py
@@ -520,7 +520,9 @@ class AdagradOptimizer(Optimizer):
                  sparse_dedup_aggregator=None, rowWise=False, engine='',
                  lars=None, output_effective_lr=False,
                  output_effective_lr_and_update=False,
-                 pruning_options=None, swa_options=None, weight_scale=None, **kwargs):
+                 pruning_options=None, swa_options=None, weight_scale=None,
+                 counter_halflife=-1,
+                 **kwargs):
         for k, v in locals().items():
             logger.info('AdagradOptimizer: input arguments: {}: {}'.format(k, v))
 
@@ -536,6 +538,7 @@ class AdagradOptimizer(Optimizer):
         self.lars = lars
         self.output_effective_lr = output_effective_lr
         self.output_effective_lr_and_update = output_effective_lr_and_update
+        self.counter_halflife = counter_halflife
         self.init_kwargs = kwargs
         self.weight_scale = weight_scale
 
@@ -628,6 +631,9 @@ class AdagradOptimizer(Optimizer):
             policy=self.policy,
             **(self.init_kwargs)
         )
+        iteration = lr_iteration
+        if self.counter_halflife > 0:
+            self._aux_params.shared.append(iteration)
 
         if self.rowWise:
             logger.info(
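Since iteration is just an alias of lr_iteration, the counter op ticks off the same global step that drives the learning-rate schedule, and when the counter is enabled that blob is registered as a shared auxiliary parameter that can be inspected after training. A hedged sketch (the blob name comes from the test at the bottom of this diff):

from caffe2.python import workspace

# After the training net has run, the shared iteration counter is a
# single-element int64 blob; "optimizer_iteration" is the name asserted
# in TestRowWiseAdagradWithCounter below.
iteration = workspace.FetchBlob("optimizer_iteration")
print(iteration)  # e.g. array([2000]) after 2000 runs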
@@ -730,6 +736,44 @@ class AdagradOptimizer(Optimizer):
                     "a delay iter needs to be provided")
 
         self._aux_params.local.append(param_squared_sum)
+        if self.counter_halflife > 0:
+            shapes, types = workspace.InferShapesAndTypes([param_init_net])
+            if str(param) not in shapes:
+                shape = param_init_net.Shape(param, str(param) + "_shape")
+                num_rows = param_init_net.Slice(
+                    [shape],
+                    str(shape) + "_numrows",
+                    starts=[0], ends=[1]
+                )
+                update_counter = param_init_net.ConstantFill(
+                    num_rows,
+                    str(param) + "_update_counter",
+                    input_as_shape=1,
+                    value=0.0,
+                )
+                prev_update_iter = param_init_net.ConstantFill(
+                    num_rows,
+                    str(param) + "_prev_update_iter",
+                    input_as_shape=1,
+                    value=0,
+                    dtype=core.DataType.INT64,
+                )
+            else:
+                update_counter = param_init_net.ConstantFill(
+                    [],
+                    str(param) + "_update_counter",
+                    shape=[shapes[str(param)][0]],
+                    value=0.0,
+                )
+                prev_update_iter = param_init_net.ConstantFill(
+                    [],
+                    str(param) + "_prev_update_iter",
+                    shape=[shapes[str(param)][0]],
+                    value=0,
+                    dtype=core.DataType.INT64,
+                )
+            self._aux_params.local.append(update_counter)
+            self._aux_params.local.append(prev_update_iter)
 
         if self.rowWise:
             assert isinstance(grad, core.GradientSlice),\
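The two init branches above differ only in how the row count is obtained: from static shape inference when InferShapesAndTypes already knows the parameter's shape, otherwise from a runtime Shape/Slice pair fed to ConstantFill via input_as_shape=1. Either way the result is one float32 counter and one int64 last-update iteration per parameter row; a rough NumPy picture (shapes and names here are illustrative):

import numpy as np

# Assumed embedding-style parameter with one counter slot per row.
param = np.zeros((1000, 64), dtype=np.float32)
num_rows = param.shape[0]  # what Shape + Slice(starts=[0], ends=[1]) extracts
update_counter = np.zeros(num_rows, dtype=np.float32)   # value=0.0
prev_update_iter = np.zeros(num_rows, dtype=np.int64)   # value=0, INT64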
@@ -801,6 +845,12 @@ class AdagradOptimizer(Optimizer):
                 epsilon=self.epsilon,
                 engine=self.engine,
             )
+            if self.counter_halflife > 0:
+                net.RowWiseCounter(
+                    [prev_update_iter, update_counter, grad.indices, iteration],
+                    [prev_update_iter, update_counter],
+                    counter_halflife=self.counter_halflife,
+                )
         else:
             input_args = [param, param_squared_sum, grad, lr]
             output_args = [param, param_squared_sum]
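For illustration, the op runs in place on its two state blobs immediately after the row-wise adagrad update. A standalone sketch of driving it by hand; blob names are invented, and it assumes a Caffe2 build in which the RowWiseCounter operator is registered:

import numpy as np
from caffe2.python import core, workspace

workspace.FeedBlob("prev_update_iter", np.zeros(4, dtype=np.int64))
workspace.FeedBlob("update_counter", np.zeros(4, dtype=np.float32))
workspace.FeedBlob("indices", np.array([0, 2], dtype=np.int64))  # rows touched
workspace.FeedBlob("iteration", np.array([10], dtype=np.int64))  # global step

net = core.Net("rowwise_counter_demo")
net.RowWiseCounter(
    ["prev_update_iter", "update_counter", "indices", "iteration"],
    ["prev_update_iter", "update_counter"],  # updated in place
    counter_halflife=5,
)
workspace.RunNetOnce(net)
print(workspace.FetchBlob("update_counter"))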
--- a/caffe2/python/optimizer_test.py
+++ b/caffe2/python/optimizer_test.py
@@ -156,6 +156,36 @@ class TestRowWiseAdagrad(OptimizerTestBase, TestCase):
     def testGPUDense(self):
         raise unittest.SkipTest("no dense support")
 
+class TestRowWiseAdagradWithCounter(OptimizerTestBase, TestCase):
+    def build_optimizer(self, model, **kwargs):
+        self._skip_gpu = True
+        return build_adagrad(
+            model,
+            base_learning_rate=1.0,
+            lars=0.5,
+            rowWise=True,
+            counter_halflife=5,
+            **kwargs
+        )
+
+    def check_optimizer(self, optimizer):
+        self.assertTrue(optimizer.get_auxiliary_parameters().shared)
+        self.assertTrue(optimizer.get_auxiliary_parameters().local)
+        self.assertTrue(workspace.HasBlob("optimizer_iteration"))
+        iteration_tensor = workspace.FetchBlob("optimizer_iteration")
+        np.testing.assert_allclose(np.array([2000]),
+                                   iteration_tensor,
+                                   atol=1e-5)
+        for param in optimizer.get_auxiliary_parameters().shared:
+            workspace.FetchBlob(param)
+        for param in optimizer.get_auxiliary_parameters().local:
+            workspace.FetchBlob(param)
+
+    def testDense(self):
+        raise unittest.SkipTest("no dense support")
+
+    def testGPUDense(self):
+        raise unittest.SkipTest("no dense support")
+
 
 class TestWngrad(OptimizerTestBase, LRModificationTestBase, TestCase):
     def build_optimizer(self, model, **kwargs):
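From the caller's side, the new knob is one extra argument to build_adagrad; a hedged sketch of enabling it (the model itself is elided; counter_halflife=5 mirrors the test above, while the default of -1 leaves the counter op out entirely):

from caffe2.python import model_helper
from caffe2.python.optimizer import build_adagrad

# Hypothetical model; any net with sparse (GradientSlice) gradients applies.
model = model_helper.ModelHelper(name="demo")
# ... add ops with sparse parameters and build gradients here ...

build_adagrad(
    model,
    base_learning_rate=1.0,
    rowWise=True,          # counter is wired into the row-wise sparse path
    counter_halflife=5,    # > 0 enables the per-row decayed update counter
)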