* Fix handling of empty batches in SumReduceDimsOp
  As titled
* Deferrable async_scheduling finishRun fix
  Proper order of finishing run operations in deferrable_async_scheduling net
* Simplify exception handling in async_scheduling
  Simplify exception handling, no need to busy wait, thread that processes the last task can finish the run
* [C2]worker_coordinator_memorize_worker_ids
  As titled. This is related to T28689868, where the number of blobs we want to create is equal to the number of worker ids
* Add unit test for nets with no type set
* Ignore total length argument in sympolic_pad_packed_sequence
  1- There was a mistake in the code that total_length was added to the wrong symbolic function (pack_padded_sequence) instead of (pad_packed_sequence)
  2- No need to throw an exception if total_length is given since it is only used to enable data_parallel training on multi-gpus and doesn't have anything to do with onnx export, so just ignore it. https://fburl.com/tk4gciqp
* Add support for MKLDNN to async_scheduling
  Just add MKLDNN as a possible CPU option to async_scheduling's pool function
* [AuFL][ensemble] support branch output for prediction
  This diff supports using predictions from different branches and thus enables model ensembling (not fully independent).
* Fix a bug in add_loss in layer_model_helper
  As titled.
* Support lradaption for adam
  1. lr adaption operator
  2. apply to dense adam
* Perf tweaks for async_scheduling
  Restore single pool option + remove unnecessary (no-ops) calls
* add quantization to SparseSimdAdagradOp
  add a bunch of quantization signatures to SparseSimdAdagradOp, implementations to come next
* [sr] [codemod] Change all SR callsites to use new API
  @allow-large-files
  This diff refactors all callsites of SR to use the slightly changed API introduced in the diff below. Really what this means is that you need to include the correct header. Also if you were using `ClientFactory::newFactory` you need to not prefix it with `ClientFactory::`.
  ```
  cd ~/fbsource/fbcode
  find ./ -type f -exec sed -i -e 's:#include "servicerouter/client/cpp2/ClientFactory.h":#include "servicerouter/client/cpp2/ServiceRouter.h":' -e 's:#include <servicerouter/client/cpp2/ClientFactory.h>:#include <servicerouter/client/cpp2/ServiceRouter.h>:' -e 's/ClientFactory::newFactory(/newFactory(/g' {} \;
  ```
  Also manually fixed spots that couldn't be done automatically (or broke because they depended on transitive includes).
* Back out "Fix handling of empty batches in SumReduceDimsOp"
  Original commit changeset: 282da1730cc2. This commit is blocking the Github->fbcode sync, which really needs to get merged ASAP. D7881937 which this diff depends on will be reverted in the sync D7990948 which causes this to break. The sync diff cannot be patched with this reversion because it must be landed against base revision 5c8c099, and D7881937 must not be included in the sync diff because it is breaking GPU tests that are not available in sandcastle: https://ci.pytorch.org/jenkins/job/caffe2-builds/job/py2-cuda8.0-cudnn6-ubuntu16.04-test/3638/console for one example.
* Add the flow to support operator benchmark
  1) generate model with the operator
  2) upload to everstore
  3) generate model spec into json file
  4) start running the benchmark
* [tum][gpu] Connect DPM trainer with flow and unit tests
  This diff:
  - Fix some small bugs for Yiming's recent changes to parallelizer, so it suits real use cases.
  - Add correct tags to the TUM code, so we can do data parallel transform
  - pass extra info when instantiation.
  - add unit test for using DPM in TUM model
  After this diff, we can do simple box, multi-gpu fully-sync trainer for TUM in Fblearner workflow, but may still need to do speed benchmarking.
* w/o normalized lradaption for adam dense only
  The previous lr adaption includes a normalization step when performing the dot product operation. This is not exactly same as what is proposed in the paper. I add normalization as an option. Without it, the operator performs exactly what the paper proposed. With the option, we add the normalization step.
* [fb] Use SharedPromise in DeferrableAsyncSchedulingNet
  This code is to simplify DeferrableAsyncSchedulingNet by removing condition variable + small fixes
* [tum] implement cuda sparseLengthsMean and LengthsMean
  as title
* Adding an optional parameter to allow use of protobufs in InferShapesAndTypes function.
* Move feature_to_index to FeatureSpec.feature_to_index
  move feature_to_index to FeatureSpec.feature_to_index to avoid override other fields
* [Caffe2] Rename bytes_moved to bytes_written
  Just a rename in preparation for supporting bytes_read.
* [c2] fix ReduceFrontSumOp for empty case by setting 0
  otherwise, it may use the results from last iteration when it's empty batch.
* [Caffe2] [Int8] Improve Intel CPU performance
* [Easy] Improve PrependDim op logging
  as titled
* DBFileReader expand db_path using os.path.expanduser(..)
  Since there are a lot of possible use cases of `DBFileReader` to read from user home path, like `~/local/sample.db`, I want to save people's trouble of calling `os.path.expanduser(db_path)` themselves.
* [Caffe2] Add bytes_read to cost structure
  We're adding analytical read bytes to cost functions. This extends the structure accordingly for all CostInference defined operators. Additionally, some small bug fixes were performed:
  1) Cost functions now extract type information of operands instead of assuming float
* Fix sleef on aarch64 for hhvm
  @bypass-lint
  Rename flag
* Remove duplicated part in caffe2/ideep/operators/conv_op.cc
  should be sync error
* Rename test helper function test_adagrad_sparse_helper to adagrad_sparse_test_helper to avoid confusing pytest
## @package optimizer
# Module caffe2.python.optimizer
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import namedtuple, defaultdict
from past.builtins import basestring

import numpy as np

from caffe2.python import core, scope, workspace
from caffe2.python.modeling import parameter_info
from caffe2.proto import caffe2_pb2


_OPTIMIZER_ITERATION_NAME = "optimizer_iteration"
_LEARNING_RATE_INJECTION = "lr_injection"

AuxOptimizerParams = namedtuple("AuxOptimizerParams", ["local", "shared"])
_optimizer_instance_count = defaultdict(int)


class Optimizer(object):
    def __init__(self):
        self._aux_params = AuxOptimizerParams(local=[], shared=[])
        self._instance_num = _optimizer_instance_count[self.__class__.__name__]
        _optimizer_instance_count[self.__class__.__name__] += 1
        self._lr_multiplier = None
        self._local_lr_multiplier = None
        self._local_lr_multiplier_on_gpu = False

    def __call__(self, net, param_init_net, param, grad=None):
        '''
        Adds optimization operators to the net for the given parameter and its
        gradient. The parameter is specified either as a ParameterInfo object
        passed in 'param' (in which case param.grad must be set), or as a pair
        of BlobReferences: 'param' for the parameter and 'grad' for its
        gradient.
        '''
        if grad is None:
            assert isinstance(param, parameter_info.ParameterInfo), (
                "Expected parameter to be of type ParameterInfo, got {}".format(
                    param
                ))
            assert param.grad is not None
        else:
            if isinstance(param, basestring):
                param = core.BlobReference(param)
            param = parameter_info.ParameterInfo(
                param_id=None, param=param, grad=grad)

        self._run(net, param_init_net, param)

    def _run(self, net, param_init_net, param_info):
        raise Exception("Not Implemented")

    def get_cpu_blob_name(self, base_str, node_name=''):
        classname = self.__class__.__name__
        return '%s_%d_%s%s_cpu' % (classname, self._instance_num, base_str, node_name)

    def get_gpu_blob_name(self, base_str, gpu_id, node_name):
        classname = self.__class__.__name__
        return '%s_%d_%s%s_gpu%d' % (
            classname, self._instance_num, base_str, node_name, gpu_id,
        )

    def make_unique_blob_name(self, base_str):
        """
        Returns a blob name that will be unique to the current device
        and optimizer instance.
        """
        current_scope = scope.CurrentDeviceScope()
        if current_scope is None:
            return self.get_cpu_blob_name(base_str)

        if current_scope.device_type == caffe2_pb2.CUDA:
            return self.get_gpu_blob_name(
                base_str, current_scope.cuda_gpu_id, current_scope.node_name
            )
        else:
            return self.get_cpu_blob_name(base_str, current_scope.node_name)

    def build_lr(self, net, param_init_net, base_learning_rate,
                 learning_rate_blob=None, policy="fixed",
                 iter_val=0, **kwargs):
        if learning_rate_blob is None:
            learning_rate_blob = self.make_unique_blob_name('lr')

        optimization_iter_blob = _OPTIMIZER_ITERATION_NAME
        if not param_init_net.BlobIsDefined(optimization_iter_blob):
            # Add training operators.
            with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
                iteration = param_init_net.ConstantFill(
                    [], optimization_iter_blob, shape=[1],
                    value=iter_val,
                    dtype=core.DataType.INT64)
                iter_mutex = param_init_net.CreateMutex(
                    [], ["iteration_mutex"]
                )
                net.AtomicIter([iter_mutex, iteration], [iteration])
        else:
            iteration = param_init_net.GetBlobRef(optimization_iter_blob)

        if not net.BlobIsDefined(learning_rate_blob):
            # There is one interesting thing here: since we are minimizing, we are
            # doing "descent" so the learning rate is set to be negative.
            lr = net.LearningRate(
                [iteration],
                learning_rate_blob,
                base_lr=-base_learning_rate,
                policy=policy,
                **kwargs
            )
        else:
            lr = net.GetBlobRef(learning_rate_blob)

        if self._lr_multiplier is not None:
            lr_multiplier = net.CopyFromCPUInput(
                self._lr_multiplier, self.make_unique_blob_name('lr_multiplier')
            )

            lr = net.Mul(
                [lr, lr_multiplier],
                self.make_unique_blob_name('scaled_lr'),
                broadcast=1,
            )

        if self._local_lr_multiplier is not None:
            current_scope = scope.CurrentDeviceScope()
            if (current_scope is not None
                    and current_scope.device_type == caffe2_pb2.CUDA
                    and not self._local_lr_multiplier_on_gpu):
                local_lr_multiplier = net.CopyFromCPUInput(
                    self._local_lr_multiplier,
                    self.make_unique_blob_name('local_lr_multiplier')
                )
            else:
                local_lr_multiplier = self._local_lr_multiplier

            lr = net.Mul(
                [lr, local_lr_multiplier],
                self.make_unique_blob_name('local_scaled_lr'),
                broadcast=1,
            )

        return lr, iteration
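
    # Illustrative usage sketch (hypothetical policy and values): a subclass's
    # _run() typically obtains its learning-rate blob as
    #
    #     lr, iteration = self.build_lr(
    #         net, param_init_net,
    #         base_learning_rate=0.1,
    #         policy="step", stepsize=10000, gamma=0.999,
    #     )
    #
    # with any extra kwargs forwarded to the LearningRate operator. Because
    # build_lr passes base_lr=-base_learning_rate, a positive
    # base_learning_rate yields a negative LR blob, which is what the
    # WeightedSum-style updates below expect for descent.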

    def add_lr_multiplier(self, lr_multiplier):
        """
        Set the global learning rate multiplier. If a multiplier already
        existed, this will overwrite the existing multiplier. The multiplier is
        used for all future calls to _run(), unless it is overwritten.
        """
        self._lr_multiplier = lr_multiplier

    def _add_local_lr_multiplier(self, local_lr_multiplier, is_gpu_blob=False):
        """
        Set the local learning rate multiplier. This local multiplier is
        multiplied with the global learning rate multiplier if it exists. As
        with the global learning rate multiplier, this multiplier will be
        used for all future calls to _run(), so please call
        _clear_local_lr_multiplier() at the beginning of the optimizer's _run()
        before optionally calling this function.
        """
        self._local_lr_multiplier = local_lr_multiplier
        self._local_lr_multiplier_on_gpu = is_gpu_blob

    def _clear_local_lr_multiplier(self):
        self._local_lr_multiplier = None
        self._local_lr_multiplier_on_gpu = False

    @staticmethod
    def dedup(net, sparse_dedup_aggregator, grad):
        assert isinstance(grad, core.GradientSlice), (
            "Dedup only works for sparse gradient, got {}".format(grad))
        if sparse_dedup_aggregator:
            return net.DeduplicateGradientSlices(
                grad, aggregator=sparse_dedup_aggregator)
        else:
            return grad

    def get_auxiliary_parameters(self):
        """Returns the auxiliary parameters.

        Returns:
            aux_params: An AuxOptimizerParams namedtuple.

            aux_params.local stores a list of blobs. Each blob is a local
            auxiliary parameter. A local auxiliary parameter is a parameter in
            parallel to a learning rate parameter. Take adagrad as an example:
            the local auxiliary parameter is the squared sum parameter, because
            every learning rate has a squared sum associated with it.

            aux_params.shared also stores a list of blobs. Each blob is a shared
            auxiliary parameter. A shared auxiliary parameter is a parameter
            that is shared across all the learning rate parameters. Take adam as
            an example: the iteration parameter is a shared parameter, because
            all the learning rates share the same iteration parameter.
        """
        return self._aux_params
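
    # Illustrative sketch (hypothetical blob names): after an optimizer has
    # been applied to every parameter, its auxiliary blobs can be collected
    # for checkpointing:
    #
    #     aux = optimizer.get_auxiliary_parameters()
    #     blobs_to_save = list(aux.local) + list(aux.shared)
    #
    # For AdagradOptimizer, aux.local holds the per-parameter squared-sum
    # blobs; for AdamOptimizer, aux.shared holds the shared
    # "optimizer_iteration" blob.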

    # TODO(xlwang): In transfer learning, a parameter initialized from a
    # pretrained model might require a different learning rate than one
    # initialized from scratch. To this end, we implement a Python solution
    # where `base_learning_rate` is scaled by `scale` via
    # `scale_learning_rate`; alternatively, the same effect could be achieved
    # by rewriting the LearningRate operator in C++.
    # Note that it is the responsibility of each specific optimizer to decide
    # what logic should be used for `scale_learning_rate`.
    def scale_learning_rate(self, *args, **kwargs):
        raise NotImplementedError(
            "Optimizers need to implement the `scale_learning_rate` method.")


class SgdOptimizer(Optimizer):
    def __init__(self, base_learning_rate=0.01, policy='fixed',
                 momentum=0.0, nesterov=1, sparse_dedup_aggregator=None,
                 lars=None, **kwargs):
        super(SgdOptimizer, self).__init__()
        self.base_learning_rate = base_learning_rate
        self.policy = policy
        self.momentum = momentum
        self.nesterov = nesterov
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.lars = lars
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad
        if self.base_learning_rate == 0:
            return
        assert self.base_learning_rate > 0, (
            "Expect positive base learning rate, got {}".format(
                self.base_learning_rate))

        self._clear_local_lr_multiplier()

        # TODO(zqq): support LARS for sparse parameters
        if self.lars is not None and not isinstance(grad, core.GradientSlice):
            assert self.lars >= 0, (
                'Lars offset must be nonnegative, got {}'.format(self.lars))
            lr_lars_multiplier = net.Lars(
                [param, grad],
                self.make_unique_blob_name(str(param) + "_lars"),
                offset=self.lars)
            current_scope = scope.CurrentDeviceScope()
            self._add_local_lr_multiplier(
                lr_lars_multiplier,
                is_gpu_blob=(current_scope is not None
                             and current_scope.device_type == caffe2_pb2.CUDA),
            )

        # We need negative sign for LR when used directly with WeightedSum
        # below.
        lr_sign = -1 if self.momentum else 1
        lr, _ = self.build_lr(
            net, param_init_net,
            base_learning_rate=self.base_learning_rate * lr_sign,
            policy=self.policy,
            **(self.init_kwargs)
        )

        dev = scope.CurrentDeviceScope()
        if dev is None:
            dev = core.DeviceOption(caffe2_pb2.CPU)

        # Each GPU/CPU must have its own ONE blob, thus modify the name
        # to include device information.
        ONE = param_init_net.ConstantFill(
            [],
            "ONE_{}_{}{}".format(dev.device_type, dev.cuda_gpu_id, dev.node_name),
            shape=[1],
            value=1.0
        )

        self._aux_params.shared.append(ONE)

        if self.momentum > 0:
            momentum_data = param_init_net.ConstantFill(
                param, str(param) + "_momentum", value=0.)
            self._aux_params.local.append(momentum_data)

        if isinstance(grad, core.GradientSlice):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            if self.momentum > 0.:
                net.SparseMomentumSGDUpdate(
                    [grad.values, momentum_data, lr, param, grad.indices],
                    [grad.values, momentum_data, param],
                    momentum=self.momentum,
                    nesterov=self.nesterov)
            else:
                net.ScatterWeightedSum(
                    [param, ONE, grad.indices, grad.values, lr],
                    param
                )
        else:
            if self.momentum > 0.:
                net.MomentumSGDUpdate(
                    [grad, momentum_data, lr, param],
                    [grad, momentum_data, param],
                    momentum=self.momentum,
                    nesterov=self.nesterov)
            else:
                coeff = lr

                net.WeightedSum(
                    [param, ONE, grad, coeff],
                    param
                )

    def scale_learning_rate(self, scale):
        self.base_learning_rate *= scale
        return
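

# Illustrative sketch (hypothetical hyperparameters), assuming `model` is a
# model_helper.ModelHelper whose gradients were added with
# AddGradientOperators: plain momentum SGD over all parameters via the
# build_sgd() helper defined later in this module.
def _example_sgd(model):
    return build_sgd(
        model,
        base_learning_rate=0.1,
        policy="step", stepsize=1, gamma=0.999,
        momentum=0.9,
        nesterov=1,
    )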


class MultiPrecisionSgdOptimizer(SgdOptimizer):
    def __init__(self, base_learning_rate=0.1, momentum=0.0,
                 policy="fixed", nesterov=1, sparse_dedup_aggregator=None,
                 **kwargs):
        super(MultiPrecisionSgdOptimizer, self).__init__(
            base_learning_rate=base_learning_rate,
            policy=policy,
            momentum=momentum,
            nesterov=nesterov,
            sparse_dedup_aggregator=sparse_dedup_aggregator,
            **kwargs
        )

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        param_fp32 = param_info.blob_copy[core.DataType.FLOAT] \
            if param_info.blob_copy is not None else None

        # If we have a straight fp32 parameter, run the base class
        if param_fp32 is None:
            return SgdOptimizer._run(self, net, param_init_net, param_info)

        grad = param_info.grad
        if self.base_learning_rate == 0:
            return
        assert self.base_learning_rate > 0, (
            "Expect positive base learning rate, got {}".format(
                self.base_learning_rate))

        lr, _ = self.build_lr(
            net, param_init_net,
            base_learning_rate=-self.base_learning_rate,
            policy=self.policy,
            **(self.init_kwargs)
        )

        momentum_data = param_init_net.ConstantFill(
            param_fp32, str(param) + "_momentum", value=0.)
        self._aux_params.local.append(momentum_data)

        assert not isinstance(grad, core.GradientSlice), (
            "MultiPrecisionSgd does not support sparse gradients")

        # Copy gradient to fp32
        grad_fp32 = net.HalfToFloat(grad, grad + "_fp32")

        # update (fused) in fp32
        net.MomentumSGDUpdate(
            [grad_fp32, momentum_data, lr, param_fp32],
            [grad_fp32, momentum_data, param_fp32],
            momentum=self.momentum,
            nesterov=self.nesterov)

        # Copy updated param back to fp16
        net.FloatToHalf(param_fp32, param)


class FP16SgdOptimizer(SgdOptimizer):
    def __init__(self, base_learning_rate=0.1, momentum=0.0,
                 policy="fixed", nesterov=1, weight_decay=0.0001,
                 sparse_dedup_aggregator=None,
                 **kwargs):
        super(FP16SgdOptimizer, self).__init__(
            base_learning_rate=base_learning_rate,
            policy=policy,
            momentum=momentum,
            nesterov=nesterov,
            sparse_dedup_aggregator=sparse_dedup_aggregator,
            **kwargs
        )
        self.weight_decay = weight_decay

    def _run(self, net, param_init_net, param_info, fp32_update=False):

        fp32_update_flag = 0
        param_name = str(param_info.blob)

        # should only be triggered in FP16 training by SpatialBN, which
        # requires FP32 params in CuDNN.
        if param_name.find("spatbn") != -1:
            fp32_update = True

        if fp32_update:
            # doing a 32bit update
            # Have to assume param_info.blob is FP32 as there is no way
            # (that I currently know of) to query a blob's type in Python
            fp32_update_flag = 1
            param = param_info.blob
            param_fp32 = param_info.blob
        else:
            if param_info.blob_copy is None:
                # doing a 32bit update
                # Have to assume param_info.blob is FP32 as there is no way
                # (that I currently know of) to query a blob's type in Python
                fp32_update_flag = 1
                param = param_info.blob
                param_fp32 = param_info.blob
            else:
                if core.DataType.FLOAT in param_info.blob_copy:
                    param = param_info.blob
                    param_fp32 = param_info.blob_copy[core.DataType.FLOAT]
                elif core.DataType.FLOAT16 in param_info.blob_copy:
                    param = param_info.blob_copy[core.DataType.FLOAT16]
                    param_fp32 = param_info.blob
                else:
                    assert (False), (
                        "Unrecognized parameter format to be updated "
                        "by FP16 Optimizer. Parameter: {}".format(param_info.name)
                    )

        grad = param_info.grad

        if self.base_learning_rate == 0:
            return
        assert self.base_learning_rate > 0, (
            "Expect positive base learning rate, got {}".format(
                self.base_learning_rate))

        lr, _ = self.build_lr(
            net, param_init_net,
            base_learning_rate=-self.base_learning_rate,
            policy=self.policy,
            **(self.init_kwargs)
        )

        momentum_data_fp32 = param_init_net.ConstantFill(
            param_fp32, str(param) + "_momentum_fp32", value=0.)

        momentum_data = param_init_net.FloatToHalf(
            momentum_data_fp32, str(param) + "_momentum")

        self._aux_params.local.append(momentum_data)

        assert not isinstance(grad, core.GradientSlice), (
            "FP16Sgd does not support sparse gradients")

        if fp32_update_flag == 0:
            net.FP16MomentumSGDUpdate(
                [grad, momentum_data, lr, param],
                [grad, momentum_data, param],
                momentum=self.momentum,
                nesterov=self.nesterov,
                weight_decay=self.weight_decay)
        else:
            # flag set to 1, therefore doing FP32 update
            net.FP32MomentumSGDUpdate(
                [grad, momentum_data_fp32, lr, param],
                [grad, momentum_data_fp32, param],
                momentum=self.momentum,
                nesterov=self.nesterov,
                weight_decay=self.weight_decay)


class WeightDecayBuilder(Optimizer):
    def __init__(self, weight_decay):
        self.weight_decay = weight_decay

    def _run(self, net, param_init_net, param_info):
        dev = scope.CurrentDeviceScope()
        if dev is None:
            dev = core.DeviceOption(caffe2_pb2.CPU)

        ONE = param_init_net.ConstantFill(
            [],
            "ONE_{}_{}".format(dev.device_type, dev.cuda_gpu_id),
            shape=[1],
            value=1.0
        )
        WD = param_init_net.ConstantFill(
            [], "wd_{}_{}".format(dev.device_type, dev.cuda_gpu_id),
            shape=[1], value=self.weight_decay
        )

        if isinstance(param_info.grad, core.GradientSlice):
            raise ValueError(
                "Weight decay does not yet support sparse gradients")
        else:
            net.WeightedSum(
                [param_info.grad, ONE, param_info.blob, WD],
                param_info.grad,
            )


class AdagradOptimizer(Optimizer):
    def __init__(self, alpha=0.01, epsilon=1e-4, decay=1, policy="fixed",
                 sparse_dedup_aggregator=None, rowWise=False, engine='',
                 lars=None, output_effective_lr=False,
                 output_effective_lr_and_update=False, **kwargs):
        super(AdagradOptimizer, self).__init__()
        self.alpha = alpha
        self.epsilon = epsilon
        self.decay = decay
        self.policy = policy
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.rowWise = rowWise
        self.engine = engine
        self.lars = lars
        self.output_effective_lr = output_effective_lr
        self.output_effective_lr_and_update = output_effective_lr_and_update
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.alpha <= 0:
            return

        self._clear_local_lr_multiplier()

        if self.lars is not None and not isinstance(grad, core.GradientSlice):
            assert self.lars >= 0, (
                'Lars offset must be nonnegative, got {}'.format(self.lars))
            lr_lars_multiplier = net.Lars(
                [param, grad],
                self.make_unique_blob_name(str(param) + "_lars"),
                offset=self.lars)
            current_scope = scope.CurrentDeviceScope()
            self._add_local_lr_multiplier(
                lr_lars_multiplier,
                is_gpu_blob=(current_scope is not None
                             and current_scope.device_type == caffe2_pb2.CUDA),
            )

        lr, _ = self.build_lr(
            net, param_init_net,
            base_learning_rate=self.alpha,
            policy=self.policy,
            **(self.init_kwargs)
        )

        if self.rowWise:
            shapes, types = workspace.InferShapesAndTypes([param_init_net])
            if str(param) not in shapes:
                # Type/shape inference is not available for this param, fallback
                # on Shape/Slice logic
                shape = param_init_net.Shape(param, str(param) + "_shape")
                num_rows = param_init_net.Slice(
                    [shape],
                    str(shape) + "_numrows",
                    starts=[0], ends=[1]
                )
                param_squared_sum = param_init_net.ConstantFill(
                    num_rows,
                    str(param) + "_avg_squared_sum",
                    input_as_shape=1,
                    value=0.0
                )
            else:
                param_squared_sum = param_init_net.ConstantFill(
                    [],
                    str(param) + "_avg_squared_sum",
                    shape=[shapes[str(param)][0]],
                    value=0.0
                )

        else:
            param_squared_sum = param_init_net.ConstantFill(
                [param],
                str(param) + "_squared_sum",
                value=0.0
            )

        self._aux_params.local.append(param_squared_sum)

        if self.rowWise:
            assert isinstance(grad, core.GradientSlice),\
                'If SparseAdagrad with rowWise=True, gradient must be '\
                'a GradientSlice. Please ensure that rowWise is not enabled '\
                'for the dense Adagrad optimizer, as it is not supported.'
        if isinstance(grad, core.GradientSlice):
            assert self.decay == 1.,\
                'Decay is not implemented for SparseAdagrad and must be set to 1'
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            if self.rowWise:
                op = 'RowWiseSparseAdagrad'
            else:
                op = 'SparseAdagrad'
            net.__getattr__(op)(
                [param, param_squared_sum, grad.indices, grad.values, lr],
                [param, param_squared_sum],
                epsilon=self.epsilon,
                engine=self.engine
            )
        else:
            output_args = [param, param_squared_sum]
            if self.output_effective_lr_and_update:
                output_args.append(str(param) + '_effective_lr')
                output_args.append(str(param) + '_update')
            elif self.output_effective_lr:
                output_args.append(str(param) + '_effective_lr')

            net.Adagrad(
                [param, param_squared_sum, grad, lr],
                output_args,
                epsilon=self.epsilon,
                decay=float(self.decay),
                engine=self.engine
            )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return
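

# Illustrative sketch (hypothetical hyperparameters): dense Adagrad through
# the build_adagrad() helper defined later in this module. rowWise=True is
# only valid for sparse gradients (core.GradientSlice), e.g. from sparse
# lookup layers, and keeps one running sum per row instead of per element.
def _example_adagrad(model):
    return build_adagrad(model, base_learning_rate=0.01, epsilon=1e-4)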


class FtrlOptimizer(Optimizer):
    def __init__(self, alpha=0.01, beta=1e-4, lambda1=0, lambda2=0,
                 sparse_dedup_aggregator=None, engine=''):
        super(FtrlOptimizer, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.engine = engine

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.alpha <= 0:
            return

        nz = param_init_net.ConstantFill(
            [param],
            str(param) + "_ftrl_nz",
            extra_shape=[2],
            value=0.0
        )
        self._aux_params.local.append(nz)
        if isinstance(grad, core.GradientSlice):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            net.SparseFtrl(
                [param, nz, grad.indices, grad.values],
                [param, nz],
                engine=self.engine,
                alpha=self.alpha,
                beta=self.beta,
                lambda1=self.lambda1,
                lambda2=self.lambda2
            )
        else:
            net.Ftrl(
                [param, nz, grad],
                [param, nz],
                engine=self.engine,
                alpha=self.alpha,
                beta=self.beta,
                lambda1=self.lambda1,
                lambda2=self.lambda2
            )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return


class AdamOptimizer(Optimizer):
    def __init__(self, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
                 policy='fixed', use_lr_adaption=False, lr_alpha=0.01,
                 normalized_lr_adaption=True, sparse_dedup_aggregator=None,
                 rowWise=False, engine='', **kwargs):
        super(AdamOptimizer, self).__init__()
        self.alpha = alpha
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.policy = policy
        self.use_lr_adaption = use_lr_adaption
        self.lr_alpha = lr_alpha
        self.normalized_lr_adaption = normalized_lr_adaption
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.rowWise = rowWise
        self.engine = engine
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.alpha <= 0:
            return

        lr, iteration = self.build_lr(
            net, param_init_net,
            base_learning_rate=self.alpha,
            policy=self.policy,
            **(self.init_kwargs)
        )

        if self.use_lr_adaption:
            effective_grad = param_init_net.ConstantFill(
                [param],
                param + "_effgrad",
                value=0.0
            )
            self._aux_params.local.append(effective_grad)
            net.LearningRateAdaption(
                [lr, grad, effective_grad],
                [lr],
                lr_alpha=self.lr_alpha,
                normalized_lr_adaption=self.normalized_lr_adaption)

        m1 = param_init_net.ConstantFill(
            [param],
            param + "_first_moment",
            value=0.0
        )

        if self.rowWise:
            shapes, types = workspace.InferShapesAndTypes([param_init_net])
            m2 = param_init_net.ConstantFill(
                [],
                param + "_avg_second_moment",
                shape=[shapes[param][0]],
                value=0.0
            )
        else:
            m2 = param_init_net.ConstantFill(
                [param],
                param + "_second_moment",
                value=0.0
            )

        self._aux_params.shared.append(iteration)
        self._aux_params.local.append(m1)
        self._aux_params.local.append(m2)

        if self.rowWise:
            assert isinstance(grad, core.GradientSlice),\
                'If SparseAdam with rowWise=True, gradient must be '\
                'a GradientSlice. Please ensure that rowWise is not enabled '\
                'for the dense Adam optimizer, as it is not supported.'
        if isinstance(grad, core.GradientSlice):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            if self.rowWise:
                op = 'RowWiseSparseAdam'
            else:
                op = 'SparseAdam'
            net.__getattr__(op)(
                [param, m1, m2, grad.indices, grad.values, lr, iteration],
                [param, m1, m2],
                beta1=self.beta1,
                beta2=self.beta2,
                epsilon=self.epsilon
            )

        else:
            if self.use_lr_adaption:
                net.Adam(
                    [param, m1, m2, grad, lr, iteration],
                    [param, m1, m2, effective_grad],
                    beta1=self.beta1,
                    beta2=self.beta2,
                    epsilon=self.epsilon)
            else:
                net.Adam(
                    [param, m1, m2, grad, lr, iteration],
                    [param, m1, m2],
                    beta1=self.beta1,
                    beta2=self.beta2,
                    epsilon=self.epsilon)

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return
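

# Illustrative sketch (hypothetical hyperparameters): Adam with the optional
# learning-rate adaption wired above, where the LearningRateAdaption op
# rescales lr from the dot product of the current gradient and the previous
# effective gradient; normalized_lr_adaption=False selects the un-normalized
# form of that update.
def _example_adam(model):
    return build_adam(
        model,
        base_learning_rate=1e-3,
        use_lr_adaption=True,
        lr_alpha=0.01,
        normalized_lr_adaption=False,
    )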


class YellowFinOptimizer(Optimizer):
    """YellowFin: An automatic tuner for momentum SGD

    See https://arxiv.org/abs/1706.03471 for more details. This implementation
    has a separate learning rate and momentum for each parameter."""

    def __init__(self,
                 alpha=0.1,
                 mu=0.0,
                 beta=0.999,
                 curv_win_width=20,
                 zero_debias=True,
                 epsilon=0.1**6,
                 policy='fixed',
                 sparse_dedup_aggregator=None,
                 **kwargs):
        super(YellowFinOptimizer, self).__init__()
        self.alpha = alpha
        self.mu = mu
        self.beta = beta
        self.curv_win_width = curv_win_width
        self.zero_debias = zero_debias
        self.epsilon = epsilon
        self.policy = policy
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):

        # Note: This is the number of persistent scalars in the YellowFin
        # optimizer. It should always be the number of scalars being used.
        # The same number should be used in the class implementing the
        # operation.
        SCALARS_MEMORY_SIZE = 5

        param = param_info.blob
        grad = param_info.grad
        moment = param_init_net.ConstantFill(
            [param],
            param + "_moment",
            value=0.0
        )
        curv_win = param_init_net.ConstantFill(
            [],
            param + "_curv_win",
            shape=[self.curv_win_width],
            value=0.0
        )
        g_avg = param_init_net.ConstantFill(
            [param],
            param + "_g_avg",
            value=0.0
        )
        g2_avg = param_init_net.ConstantFill(
            [param],
            param + "_g2_avg",
            value=0.0
        )
        lr_avg = param_init_net.ConstantFill(
            [],
            param + "_lr_avg",
            shape=[1],
            value=self.alpha
        )
        mu_avg = param_init_net.ConstantFill(
            [],
            param + "_mu_avg",
            shape=[1],
            value=self.mu
        )
        scalars_memory = param_init_net.ConstantFill(
            [],
            param + "_scalars_memory",
            shape=[SCALARS_MEMORY_SIZE],
            value=0.0
        )

        assert self.alpha > 0
        assert not isinstance(grad, core.GradientSlice), \
            "YellowFin does not support sparse gradients"

        if not param_init_net.BlobIsDefined(_OPTIMIZER_ITERATION_NAME):
            # Add training operators.
            with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
                iteration = param_init_net.ConstantFill(
                    [],
                    _OPTIMIZER_ITERATION_NAME,
                    shape=[1],
                    value=0,
                    dtype=core.DataType.INT64)
                iter_mutex = param_init_net.CreateMutex([],
                                                        ["iteration_mutex"])
                net.AtomicIter([iter_mutex, iteration], [iteration])
        else:
            iteration = param_init_net.GetBlobRef(_OPTIMIZER_ITERATION_NAME)

        self._aux_params.shared.append(iteration)
        self._aux_params.local.append(moment)
        self._aux_params.local.append(lr_avg)
        self._aux_params.local.append(mu_avg)
        self._aux_params.local.append(curv_win)
        self._aux_params.local.append(g_avg)
        self._aux_params.local.append(g2_avg)
        self._aux_params.local.append(scalars_memory)

        yf_in_out_args = [
            param,
            moment,
            lr_avg,
            mu_avg,
            curv_win,
            g_avg,
            g2_avg,
            scalars_memory
        ]

        net.YellowFin(
            yf_in_out_args + [grad, iteration],
            yf_in_out_args,
            beta=self.beta,
            epsilon=self.epsilon,
            curv_win_width=self.curv_win_width,
            zero_debias=self.zero_debias)

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return


class RmsPropOptimizer(Optimizer):
    def __init__(
        self,
        alpha=0.01,
        decay=0.9,
        momentum=0.0,
        epsilon=1e-5,
        policy='fixed',
        engine='',
        **kwargs
    ):
        super(RmsPropOptimizer, self).__init__()
        self.alpha = alpha
        self.decay = decay
        self.momentum = momentum
        self.epsilon = epsilon
        self.policy = policy
        self.engine = engine
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        assert self.alpha > 0
        assert not isinstance(grad, core.GradientSlice), \
            "RmsPropOptimizer doesn't support sparse gradients"

        dev = scope.CurrentDeviceScope()
        if dev is None:
            dev = core.DeviceOption(caffe2_pb2.CPU)

        ONE = param_init_net.ConstantFill(
            [],
            "ONE_{}_{}".format(dev.device_type, dev.cuda_gpu_id),
            shape=[1],
            value=1.0
        )

        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=-self.alpha,
            policy=self.policy,
            **(self.init_kwargs)
        )

        grad_o = param_init_net.ConstantFill(
            [param],
            str(param) + "_grad_o",
            value=0.0,
        )

        ms = param_init_net.ConstantFill(
            [param],
            str(param) + "_mean_squares",
            value=0.0,
        )

        mom = param_init_net.ConstantFill(
            [param],
            str(param) + "_momentum",
            value=0.0,
        )

        self._aux_params.local.append(ms)
        self._aux_params.local.append(mom)

        net.RmsProp(
            [grad, ms, mom, ONE],
            [grad_o, ms, mom],
            decay=self.decay,
            momentum=self.momentum,
            epsilon=self.epsilon,
            engine=self.engine,
        )

        net.MomentumSGDUpdate(
            [grad_o, mom, lr, param],
            [grad_o, mom, param],
        )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return
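

# Illustrative sketch (hypothetical hyperparameters): RMSProp as wired above
# is a two-step update, with the RmsProp op rescaling the gradient into
# grad_o from the running mean of squares and MomentumSGDUpdate then applying
# grad_o with the (negative) learning rate blob.
def _example_rms_prop(model):
    return build_rms_prop(
        model,
        base_learning_rate=0.01,
        decay=0.9,
        momentum=0.0,
        epsilon=1e-5,
    )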


def _get_param_to_device(model):
    # Infer blob devices by going through the net and param_init_net
    # ops and observing the device used to create or use the blob.
    param_to_device = core.InferBlobDevices(model.net)
    param_to_device.update(core.InferBlobDevices(model.param_init_net))
    return param_to_device


def get_param_device(param_name, grad, param_to_device=None, default_device=None):
    device = default_device
    param_to_device = param_to_device or {}
    # We first check if the parameter's device has been inferred. If not,
    # we check the gradient. This can happen if the parameter is not the
    # output of any op but was created directly in the workspace.
    if param_name in param_to_device:
        device = param_to_device[param_name]
    else:
        if isinstance(grad, core.GradientSlice):
            grad = grad
            if str(grad.values) in param_to_device:
                device = param_to_device[str(grad.values)]
            elif str(grad.indices) in param_to_device:
                device = param_to_device[str(grad.indices)]
        else:
            grad_name = str(grad)
            if grad_name in param_to_device:
                device = param_to_device[grad_name]

    assert device is not None,\
        "Cannot infer device for {}: no op creates it".format(param_name)
    return device


def get_lr_injection():
    """
    Gets current value for lr_injection, a multiplier for all base
    learning rates.
    Must set allow_lr_injection=True when building optimizer, as it
    relies on synchronization over CPU.
    """
    return workspace.FetchBlob(_LEARNING_RATE_INJECTION)


def set_lr_injection(lr_injection_value):
    """
    Sets lr_injection, a multiplier for all base learning rates.
    Must set allow_lr_injection=True when building optimizer, as it
    relies on synchronization over CPU.
    """
    workspace.FeedBlob(
        _LEARNING_RATE_INJECTION,
        np.array(
            [float(lr_injection_value)],
            dtype=np.float32,
        ),
    )
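

# Illustrative sketch: lr_injection lets a training loop rescale every base
# learning rate at run time, provided the optimizer was built with
# allow_lr_injection=True (see _build below) so that the "lr_injection" blob
# exists in the workspace.
def _example_halve_all_learning_rates():
    current = float(get_lr_injection())
    set_lr_injection(current * 0.5)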


def _calc_norm_ratio(
    model, params, name_scope, param_to_device, max_gradient_norm
):
    with core.NameScope(name_scope):
        grad_squared_sums = []
        for i, param in enumerate(params):
            device = get_param_device(
                str(param.blob), param.grad, param_to_device
            )

            with core.DeviceScope(device):
                grad = (
                    param.grad
                    if not isinstance(
                        param.grad,
                        core.GradientSlice,
                    ) else param.grad.values
                )

                grad_squared_sum_name = 'grad_{}_squared_sum'.format(i)
                grad_squared_sum = model.net.SumSqrElements(
                    grad,
                    grad_squared_sum_name,
                )
                grad_squared_sum_cpu = model.net.EnsureCPUOutput(
                    grad_squared_sum
                )
                grad_squared_sums.append(grad_squared_sum_cpu)

        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
            grad_squared_full_sum = model.net.Sum(
                grad_squared_sums,
                'grad_squared_full_sum',
            )
            global_norm = model.net.Pow(
                grad_squared_full_sum,
                'global_norm',
                exponent=0.5,
            )
            clip_norm = model.param_init_net.ConstantFill(
                [],
                'clip_norm',
                shape=[],
                value=float(max_gradient_norm),
            )
            max_norm = model.net.Max(
                [global_norm, clip_norm],
                'max_norm',
            )
            norm_ratio = model.net.Div(
                [clip_norm, max_norm],
                'norm_ratio',
            )
            return norm_ratio
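

# Illustrative arithmetic for the ratio computed above, assuming
# max_gradient_norm=5.0 (values are made up):
#   global_norm = 20.0 -> norm_ratio = 5.0 / max(20.0, 5.0) = 0.25
#   global_norm =  2.0 -> norm_ratio = 5.0 / max(2.0, 5.0)  = 1.0  (no clipping)
# The ratio is multiplied into the learning rate, so updates are only ever
# scaled down, and only when the global gradient norm exceeds the clip value.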


def _build(
    model,
    optimizer,
    weights_only=False,
    use_param_info_optim=True,
    max_gradient_norm=None,
    allow_lr_injection=False,
):
    param_to_device = _get_param_to_device(model)

    # Validate there are no duplicate params
    model.Validate()

    params = []
    for param_info in model.GetOptimizationParamInfo():
        if weights_only and param_info.blob not in model.weights:
            continue
        params.append(param_info)

    lr_multiplier = None
    if max_gradient_norm is not None:
        lr_multiplier = _calc_norm_ratio(
            model,
            params,
            'norm_clipped_grad_update',
            param_to_device,
            max_gradient_norm,
        )

    if allow_lr_injection:
        if not model.net.BlobIsDefined(_LEARNING_RATE_INJECTION):
            lr_injection = model.param_init_net.ConstantFill(
                [],
                _LEARNING_RATE_INJECTION,
                shape=[1],
                value=1.0,
            )
        else:
            lr_injection = _LEARNING_RATE_INJECTION

        if lr_multiplier is None:
            lr_multiplier = lr_injection
        else:
            lr_multiplier = model.net.Mul(
                [lr_multiplier, lr_injection],
                'lr_multiplier',
                broadcast=1,
            )
    optimizer.add_lr_multiplier(lr_multiplier)

    for param_info in params:
        param_name = str(param_info.blob)

        device = get_param_device(param_name, param_info.grad, param_to_device)

        with core.DeviceScope(device):
            if param_info.optimizer and use_param_info_optim:
                param_info.optimizer(model.net, model.param_init_net, param_info)
            else:
                optimizer(model.net, model.param_init_net, param_info)
    return optimizer


def add_weight_decay(model, weight_decay):
    """Adds a decay to weights in the model.

    This is a form of L2 regularization.

    Args:
        weight_decay: strength of the regularization
    """
    _build(
        model,
        WeightDecayBuilder(weight_decay=weight_decay),
        weights_only=True,
        use_param_info_optim=False,
    )
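

# Illustrative sketch (hypothetical decay strength): weight decay is applied
# as a separate pass over the model's weights before the main optimizer is
# built, rewriting each gradient as grad += weight_decay * param via
# WeightedSum.
def _example_weight_decay_then_sgd(model):
    add_weight_decay(model, weight_decay=1e-4)
    return build_sgd(model, base_learning_rate=0.1, momentum=0.9)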


def build_sgd(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    sgd_optimizer = SgdOptimizer(base_learning_rate, **kwargs)
    return _build(
        model,
        sgd_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )
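

# Illustrative end-to-end sketch (all blob names, sizes and hyperparameters
# are hypothetical): building a tiny model with brew/model_helper, requesting
# gradients, attaching SGD with this module, and running one step.
def _example_train_one_step():
    from caffe2.python import brew, model_helper

    workspace.FeedBlob("data", np.random.rand(4, 16).astype(np.float32))
    workspace.FeedBlob("label", np.random.randint(0, 2, size=4).astype(np.int32))

    model = model_helper.ModelHelper(name="optimizer_example")
    fc = brew.fc(model, "data", "fc", dim_in=16, dim_out=2)
    softmax, loss = model.net.SoftmaxWithLoss([fc, "label"], ["softmax", "loss"])
    model.AddGradientOperators([loss])

    build_sgd(model, base_learning_rate=0.1, policy="fixed", momentum=0.9)

    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)
    workspace.RunNet(model.net)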


def build_multi_precision_sgd(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    multi_prec_sgd_optimizer = MultiPrecisionSgdOptimizer(
        base_learning_rate, **kwargs
    )
    return _build(
        model,
        multi_prec_sgd_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_fp16_sgd(model, base_learning_rate, **kwargs):
    fp16_sgd_optimizer = FP16SgdOptimizer(
        base_learning_rate, **kwargs
    )
    return _build(model, fp16_sgd_optimizer)


def build_ftrl(model, engine="SIMD", **kwargs):
    if engine == "SIMD":
        assert core.IsOperator('Ftrl_ENGINE_SIMD')
        assert core.IsOperator('SparseFtrl_ENGINE_SIMD')
    ftrl_optimizer = FtrlOptimizer(engine=engine, **kwargs)
    return _build(model, ftrl_optimizer)


def build_adagrad(
    model,
    base_learning_rate,
    parameters=None,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    adagrad_optimizer = AdagradOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        adagrad_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_adam(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    adam_optimizer = AdamOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        adam_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_yellowfin(model, base_learning_rate=0.1, **kwargs):
    yellowfin_optimizer = YellowFinOptimizer(
        alpha=base_learning_rate,
        **kwargs)
    return _build(model, yellowfin_optimizer)


def build_rms_prop(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    rms_prop_optimizer = RmsPropOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        rms_prop_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )