# @package adaptive_weight
# Module caffe2.fb.python.layers.adaptive_weight
from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np

from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer
from caffe2.python.regularizer import BoundedGradientProjection, LogBarrier


"""
Implementation of adaptive weighting: https://arxiv.org/pdf/1705.07115.pdf
"""


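# In equation form: given per-task losses X_i and a learned parameter vector,
# the layer outputs sum_i (w_i * X_i + r_i), where the per-task weight w_i and
# regularization term r_i depend on estimation_method (log_std or inv_var; see
# the log_std_* and inv_var_* methods below).
#
# A minimal usage sketch (hedged; the record and loss blob names are
# illustrative, not part of this module):
#
#   adaptive_loss = AdaptiveWeight(
#       model,
#       schema.Struct(("loss_a", loss_a), ("loss_b", loss_b)),
#       estimation_method="log_std",
#   )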
class AdaptiveWeight(ModelLayer):
    def __init__(
        self,
        model,
        input_record,
        name="adaptive_weight",
        optimizer=None,
        weights=None,
        enable_diagnose=False,
        estimation_method="log_std",
        pos_optim_method="log_barrier",
        reg_lambda=0.1,
        **kwargs
    ):
        super(AdaptiveWeight, self).__init__(model, name, input_record, **kwargs)
        self.output_schema = schema.Scalar(
            np.float32, self.get_next_blob_reference("adaptive_weight")
        )
        self.data = self.input_record.field_blobs()
        self.num = len(self.data)
        self.optimizer = optimizer
        if weights is not None:
            assert len(weights) == self.num
        else:
            weights = [1. / self.num for _ in range(self.num)]
        assert min(weights) > 0, "initial weights must be positive"
        self.weights = np.array(weights).astype(np.float32)
        self.estimation_method = str(estimation_method).lower()
        # used in the positivity-constrained parameterization, i.e. when the
        # estimation method is inv_var and the optimization method is either
        # log barrier or gradient projection
        self.pos_optim_method = str(pos_optim_method).lower()
        self.reg_lambda = float(reg_lambda)
        self.enable_diagnose = enable_diagnose
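        # dispatch on estimation_method: each *_func resolves to one of the
        # log_std_{init,weight,reg} or inv_var_{init,weight,reg} methods below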
        self.init_func = getattr(self, self.estimation_method + "_init")
        self.weight_func = getattr(self, self.estimation_method + "_weight")
        self.reg_func = getattr(self, self.estimation_method + "_reg")
        self.init_func()
        if self.enable_diagnose:
            self.weight_i = [
                self.get_next_blob_reference("adaptive_weight_%d" % i)
                for i in range(self.num)
            ]
            for i in range(self.num):
                self.model.add_ad_hoc_plot_blob(self.weight_i[i])

    def concat_data(self, net):
        reshaped = [net.NextScopedBlob("reshaped_data_%d" % i) for i in range(self.num)]
        # coerce shape for single real values
        for i in range(self.num):
            net.Reshape(
                [self.data[i]],
                [reshaped[i], net.NextScopedBlob("new_shape_%d" % i)],
                shape=[1],
            )
        concated = net.NextScopedBlob("concated_data")
        net.Concat(
            reshaped, [concated, net.NextScopedBlob("concated_new_shape")], axis=0
        )
        return concated

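    # In NumPy terms, concat_data is roughly the following (a hedged sketch;
    # the actual ops run inside the Caffe2 net):
    #
    #   import numpy as np
    #   losses = [np.float32(1.2), np.float32(3.4)]          # scalar loss blobs
    #   x = np.concatenate([np.reshape(l, (1,)) for l in losses], axis=0)
    #   assert x.shape == (2,)                               # one entry per task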
    def log_std_init(self):
        """
        mu = 2 log sigma, where sigma is the standard deviation

        per task objective:
        min 1 / (2 e^mu) * X + mu / 2
        """
        values = np.log(1. / 2. / self.weights)
        initializer = (
            "GivenTensorFill",
            {"values": values, "dtype": core.DataType.FLOAT},
        )
        self.mu = self.create_param(
            param_name="mu",
            shape=[self.num],
            initializer=initializer,
            optimizer=self.optimizer,
        )

    def log_std_weight(self, x, net, weight):
        """
        min 1 / (2 e^mu) * X + mu / 2
        """
        mu_neg = net.NextScopedBlob("mu_neg")
        net.Negative(self.mu, mu_neg)
        mu_neg_exp = net.NextScopedBlob("mu_neg_exp")
        net.Exp(mu_neg, mu_neg_exp)
        net.Scale(mu_neg_exp, weight, scale=0.5)

    def log_std_reg(self, net, reg):
        net.Scale(self.mu, reg, scale=0.5)

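    # A NumPy sketch of the log_std parameterization (hedged, outside the net):
    # per task, weight = 0.5 * exp(-mu) and reg = 0.5 * mu, and the init value
    # mu = log(1 / (2 w)) recovers the requested initial weight w:
    #
    #   import numpy as np
    #   w = np.array([0.25, 0.75], dtype=np.float32)   # initial weights
    #   mu = np.log(1. / 2. / w)                       # log_std_init values
    #   assert np.allclose(0.5 * np.exp(-mu), w)       # log_std_weight output
    #   reg = 0.5 * mu                                 # log_std_reg output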
    def inv_var_init(self):
        """
        k = 1 / variance

        per task objective:
        min 1 / 2 * k * X - 1 / 2 * log k
        """
        values = 2. * self.weights
        initializer = (
            "GivenTensorFill",
            {"values": values, "dtype": core.DataType.FLOAT},
        )
        if self.pos_optim_method == "log_barrier":
            regularizer = LogBarrier(reg_lambda=self.reg_lambda)
        elif self.pos_optim_method == "pos_grad_proj":
            regularizer = BoundedGradientProjection(lb=0, left_open=True)
        else:
            raise TypeError(
                "unknown positivity optimization method: {}".format(
                    self.pos_optim_method
                )
            )
        self.k = self.create_param(
            param_name="k",
            shape=[self.num],
            initializer=initializer,
            optimizer=self.optimizer,
            regularizer=regularizer,
        )

    def inv_var_weight(self, x, net, weight):
        net.Scale(self.k, weight, scale=0.5)

    def inv_var_reg(self, net, reg):
        log_k = net.NextScopedBlob("log_k")
        net.Log(self.k, log_k)
        net.Scale(log_k, reg, scale=-0.5)

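    # A NumPy sketch of the inv_var parameterization (hedged, outside the net):
    # per task, weight = 0.5 * k and reg = -0.5 * log k; the init value k = 2 w
    # recovers the requested initial weight w, and the regularizer attached in
    # inv_var_init keeps k (hence the weight) strictly positive:
    #
    #   import numpy as np
    #   w = np.array([0.25, 0.75], dtype=np.float32)   # initial weights
    #   k = 2. * w                                     # inv_var_init values
    #   assert np.allclose(0.5 * k, w)                 # inv_var_weight output
    #   reg = -0.5 * np.log(k)                         # inv_var_reg output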
    def _add_ops_impl(self, net, enable_diagnose):
        x = self.concat_data(net)
        weight = net.NextScopedBlob("weight")
        reg = net.NextScopedBlob("reg")
        weighted_x = net.NextScopedBlob("weighted_x")
        weighted_x_add_reg = net.NextScopedBlob("weighted_x_add_reg")
        self.weight_func(x, net, weight)
        self.reg_func(net, reg)
        net.Mul([weight, x], weighted_x)
        net.Add([weighted_x, reg], weighted_x_add_reg)
        net.SumElements(weighted_x_add_reg, self.output_schema())
        if enable_diagnose:
            for i in range(self.num):
                net.Slice(weight, self.weight_i[i], starts=[i], ends=[i + 1])

    def add_ops(self, net):
        self._add_ops_impl(net, self.enable_diagnose)
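
# A NumPy sketch of the full forward pass that _add_ops_impl builds (hedged;
# the two-task values are illustrative):
#
#   import numpy as np
#   x = np.array([1.2, 3.4], dtype=np.float32)   # concatenated task losses
#   mu = np.zeros(2, dtype=np.float32)           # log_std params, w = 0.5 each
#   weight = 0.5 * np.exp(-mu)                   # weight_func
#   reg = 0.5 * mu                               # reg_func
#   out = np.sum(weight * x + reg)               # Mul, Add, SumElements
#   # out == 2.3, the equal-weight combination 0.5 * (1.2 + 3.4)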