#include <torch/optim/lbfgs.h>

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <vector>
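
// Example usage (a minimal sketch; `model`, `input`, `target`, and `loss_fn`
// are placeholders assumed to be defined by the caller):
//
//   torch::optim::LBFGS optimizer(
//       model->parameters(),
//       torch::optim::LBFGSOptions(1.0).max_iter(20).line_search_fn(
//           "strong_wolfe"));
//   auto closure = [&] {
//     optimizer.zero_grad();
//     auto loss = loss_fn(model->forward(input), target);
//     loss.backward();
//     return loss;
//   };
//   optimizer.step(closure);
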
namespace torch {
namespace optim {

LBFGSOptions::LBFGSOptions(double lr) : lr_(lr) {}

bool operator==(const LBFGSOptions& lhs, const LBFGSOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.max_iter() == rhs.max_iter()) &&
      (lhs.max_eval() == rhs.max_eval()) &&
      (lhs.tolerance_grad() == rhs.tolerance_grad()) &&
      (lhs.tolerance_change() == rhs.tolerance_change()) &&
      (lhs.history_size() == rhs.history_size()) &&
      (lhs.line_search_fn() == rhs.line_search_fn());
}

void LBFGSOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_iter);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_eval);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(tolerance_grad);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(tolerance_change);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(history_size);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(line_search_fn);
}

void LBFGSOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, max_iter);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(int64_t, max_eval);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, tolerance_grad);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, tolerance_change);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, history_size);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(std::string, line_search_fn);
}

double LBFGSOptions::get_lr() const {
  return lr();
}

void LBFGSOptions::set_lr(const double lr) {
  this->lr(lr);
}

// Element-wise equality of two same-typed containers of tensors.
template <typename T>
bool if_container_equal(const T& lhs, const T& rhs) {
  if (lhs.size() != rhs.size())
    return false;
  for (const auto i : c10::irange(lhs.size())) {
    if (!torch::equal(lhs.at(i), rhs.at(i)))
      return false;
  }
  return true;
}

bool operator==(const LBFGSParamState& lhs, const LBFGSParamState& rhs) {
  auto isNull = [](const std::optional<std::vector<Tensor>>& val) {
    return val == std::nullopt;
  };
  return (lhs.func_evals() == rhs.func_evals()) &&
      (lhs.n_iter() == rhs.n_iter()) && (lhs.t() == rhs.t()) &&
      (lhs.prev_loss() == rhs.prev_loss()) &&
      torch::equal_if_defined(lhs.d(), rhs.d()) &&
      torch::equal_if_defined(lhs.H_diag(), rhs.H_diag()) &&
      torch::equal_if_defined(lhs.prev_flat_grad(), rhs.prev_flat_grad()) &&
      if_container_equal(lhs.old_dirs(), rhs.old_dirs()) &&
      if_container_equal(lhs.old_stps(), rhs.old_stps()) &&
      if_container_equal(lhs.ro(), rhs.ro()) &&
      ((isNull(lhs.al()) && isNull(rhs.al())) ||
       (!isNull(lhs.al()) && !isNull(rhs.al()) &&
        if_container_equal(*lhs.al(), *rhs.al())));
}

void LBFGSParamState::serialize(
    torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(func_evals);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(n_iter);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(t);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(prev_loss);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(d);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(H_diag);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(prev_flat_grad);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_dirs);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_stps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(ro);
  // Python version only serializes state vars if explicitly defined
  if (al() != std::nullopt) {
    _TORCH_OPTIM_SERIALIZE_TORCH_ARG(al);
  }
}

void LBFGSParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, func_evals);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, n_iter);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, t);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, prev_loss);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, d);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, H_diag);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, prev_flat_grad);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, old_dirs);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, old_stps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, ro);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(std::vector<Tensor>, al);
}

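// Flattens the gradients of every parameter in the (single) parameter group
// into one contiguous 1-D tensor. Undefined gradients contribute zeros and
// sparse gradients are densified first.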
Tensor LBFGS::_gather_flat_grad() {
  std::vector<Tensor> views;
  for (const auto& p : param_groups_.at(0).params()) {
    if (!p.grad().defined()) {
      views.emplace_back(p.new_empty({p.numel()}).zero_());
    } else if (p.grad().is_sparse()) {
      views.emplace_back(p.grad().to_dense().view(-1));
    } else {
      views.emplace_back(p.grad().view(-1));
    }
  }
  return torch::cat(views, 0);
}

int64_t LBFGS::_numel() {
  if (_numel_cache == std::nullopt) {
    // accumulate in int64_t (not int) so the total cannot overflow for
    // very large models
    int64_t res = 0;
    for (const auto& p : param_groups_.at(0).params()) {
      res += p.numel();
    }
    _numel_cache = res;
  }
  return *_numel_cache;
}

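// Applies `p += step_size * update[offset:offset + p.numel()]` to each
// parameter, i.e. distributes one flat update vector back over the
// individual parameter tensors.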
void LBFGS::_add_grad(const double step_size, const Tensor& update) {
  int64_t offset = 0;
  for (auto& p : param_groups_.at(0).params()) {
    auto numel = p.numel();
    // view as to avoid deprecated pointwise semantics
    p.add_(
        update.index({at::indexing::Slice(offset, offset + numel)}).view_as(p),
        step_size);
    offset += numel;
  }
  TORCH_INTERNAL_ASSERT(offset == _numel());
}

void LBFGS::_set_param(const std::vector<Tensor>& params_data) {
  auto& _params = param_groups_.at(0).params();
  TORCH_INTERNAL_ASSERT(params_data.size() == _params.size());
  for (const auto i : c10::irange(_params.size())) {
    _params.at(i).copy_(params_data.at(i));
  }
}

std::vector<Tensor> LBFGS::_clone_param() {
  std::vector<Tensor> result;
  for (const auto& p : param_groups_.at(0).params()) {
    result.emplace_back(p.clone(at::MemoryFormat::Contiguous));
  }
  return result;
}

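// Evaluates the objective at the trial point x + t * d: returns the loss and
// the flattened gradient there, then restores the parameters to `x` so the
// line search leaves the current iterate untouched.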
std::tuple<double, Tensor> LBFGS::_directional_evaluate(
    const LossClosure& closure,
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d) {
  _add_grad(t, d);
  double loss = 0;
  {
    torch::AutoGradMode enable_grad(true);
    loss = closure().item<double>();
  }
  auto flat_grad = _gather_flat_grad();
  _set_param(x);
  return std::make_tuple(loss, flat_grad);
}

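// Finds the minimizer of the cubic polynomial interpolating two points
// (x1, f1) and (x2, f2) together with their derivatives g1 and g2, clamped
// to [xmin_bound, xmax_bound]. If the cubic has no real critical point
// (d2_square < 0), it falls back to bisecting the interval.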
static double _cubic_interpolate(
    double x1,
    double f1,
    double g1,
    double x2,
    double f2,
    double g2,
    std::optional<std::tuple<double, double>> bounds = std::nullopt) {
  // ported from https://github.com/torch/optim/blob/master/polyinterp.lua
  // Compute bounds of interpolation area
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double xmin_bound, xmax_bound;
  if (bounds != std::nullopt) {
    std::tie(xmin_bound, xmax_bound) = *bounds;
  } else {
    std::tie(xmin_bound, xmax_bound) =
        (x1 <= x2) ? std::make_tuple(x1, x2) : std::make_tuple(x2, x1);
  }
  // Code for most common case: cubic interpolation of 2 points
  // w/ function and derivative values for both
  // Solution in this case (where x2 is the farthest point):
  //   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
  //   d2 = sqrt(d1^2 - g1*g2);
  //   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
  //   t_new = min(max(min_pos,xmin_bound),xmax_bound);

  auto d1 = (g1 + g2) - (3 * (f1 - f2) / (x1 - x2));
  auto d2_square = std::pow(d1, 2) - g1 * g2;
  if (d2_square >= 0) {
    auto d2 = std::sqrt(d2_square);
    double min_pos = 0;
    if (x1 <= x2) {
      min_pos = x2 - ((x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)));
    } else {
      min_pos = x1 - ((x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)));
    }
    return std::min(std::max(min_pos, xmin_bound), xmax_bound);
  } else {
    return (xmin_bound + xmax_bound) / 2;
  }
}

using Function = std::function<std::tuple<double, Tensor>(
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d)>;
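
// Strong Wolfe line search: searches for a step size t along direction d
// satisfying
//   f(x + t*d) <= f + c1 * t * g'd     (sufficient decrease / Armijo)
//   |g(x + t*d)'d| <= c2 * |g'd|       (curvature condition)
// via a bracketing phase followed by a "zoom" refinement phase (see
// Nocedal & Wright, Numerical Optimization, Ch. 3).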
static std::tuple<double, Tensor, double, int64_t> _strong_wolfe(
    const Function& obj_func,
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d,
    double f,
    Tensor g,
    const Tensor& gtd,
    double c1 = 1e-4,
    double c2 = 0.9, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    double tolerance_change = 1e-9,
    double max_ls = 25) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)

  auto val = [](const Tensor& t) { return t.item<double>(); };

  auto d_norm = val(d.abs().max());
  g = g.clone(at::MemoryFormat::Contiguous);
  // evaluate objective and gradient using initial step
  auto [f_new, g_new] = obj_func(x, t, d);
  int64_t ls_func_evals = 1;
  auto gtd_new = g_new.dot(d);

  // bracket an interval containing a point satisfying the Wolfe criteria
  double t_prev = 0;
  auto f_prev = f;
  auto g_prev = g;
  auto gtd_prev = gtd;
  bool done = false;
  auto ls_iter = 0;
  std::vector<double> bracket, bracket_f;
  std::vector<Tensor> bracket_g, bracket_gtd;

  while (ls_iter < max_ls) {
    // check conditions
    if ((f_new > (f + c1 * t * val(gtd))) ||
        (ls_iter > 1 && (f_new >= f_prev))) {
      bracket = {t_prev, t};
      bracket_f = {f_prev, f_new};
      bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)};
      bracket_gtd = {gtd_prev, gtd_new};
      break;
    }
    if (std::abs(val(gtd_new)) <= (-c2 * val(gtd))) {
      bracket = {t, t};
      bracket_f = {f_new, f_new};
      bracket_g = {g_new, g_new};
      done = true;
      break;
    }
    if (val(gtd_new) >= 0) {
      bracket = {t_prev, t};
      bracket_f = {f_prev, f_new};
      bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)};
      bracket_gtd = {gtd_prev, gtd_new};
      break;
    }
    // interpolate
    auto min_step = t +
        0.01 * (t - t_prev); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    auto max_step = t * 10; // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    auto tmp = t;
    t = _cubic_interpolate(
        t_prev,
        f_prev,
        val(gtd_prev),
        t,
        f_new,
        val(gtd_new),
        std::make_tuple(min_step, max_step));
    // next step
    t_prev = tmp;
    f_prev = f_new;
    g_prev = g_new.clone(at::MemoryFormat::Contiguous);
    gtd_prev = gtd_new;
    std::tie(f_new, g_new) = obj_func(x, t, d);
    ls_func_evals += 1;
    gtd_new = g_new.dot(d);
    ls_iter += 1;
  }
  // reached max number of iterations?
  if (ls_iter == max_ls) {
    bracket = {0, t};
    bracket_f = {f, f_new};
    bracket_g = {g, g_new};
  }

  // zoom phase: we now have a point satisfying the criteria, or
  // a bracket around it. We refine the bracket until we find the
  // exact point satisfying the criteria
  bool insuf_progress = false;
  // find high and low points in bracket
  auto [low_pos, high_pos] = bracket_f[0] <= bracket_f[1]
      ? std::make_tuple(0, 1)
      : std::make_tuple(1, 0);
  while (!done && (ls_iter < max_ls)) {
    // compute new trial value
    t = _cubic_interpolate(
        bracket[0],
        bracket_f[0],
        val(bracket_gtd[0]),
        bracket[1],
        bracket_f[1],
        val(bracket_gtd[1]));

    // test that we are making sufficient progress:
    // if `t` is too close to a boundary, we mark that we are making
    // insufficient progress, and if
    //   + we made insufficient progress in the last step, or
    //   + `t` is at one of the boundaries,
    // we move `t` to a position `0.1 * len(bracket)` away from the
    // nearest boundary point.
    double bracket_max = std::max(bracket[0], bracket[1]);
    auto bracket_min = std::min(bracket[0], bracket[1]);
    auto eps = 0.1 * // NOLINT(cppcoreguidelines-avoid-magic-numbers)
        (bracket_max - bracket_min);
    if (std::min(bracket_max - t, t - bracket_min) < eps) {
      // interpolation close to boundary
      if (insuf_progress || (t >= bracket_max) || (t <= bracket_min)) {
        // evaluate at 0.1 away from boundary
        t = (std::abs(t - bracket_max) < std::abs(t - bracket_min))
            ? bracket_max - eps
            : bracket_min + eps;
        insuf_progress = false;
      } else {
        insuf_progress = true;
      }
    } else {
      insuf_progress = false;
    }

    // Evaluate new point
    std::tie(f_new, g_new) = obj_func(x, t, d);
    ls_func_evals += 1;
    gtd_new = g_new.dot(d);
    ls_iter += 1;

    if ((f_new > (f + c1 * t * val(gtd))) || (f_new >= bracket_f[low_pos])) {
      // Armijo condition not satisfied or not lower than lowest point
      bracket[high_pos] = t;
      bracket_f[high_pos] = f_new;
      bracket_g[high_pos] = g_new.clone(at::MemoryFormat::Contiguous);
      bracket_gtd[high_pos] = gtd_new;
      std::tie(low_pos, high_pos) = bracket_f[0] <= bracket_f[1]
          ? std::make_tuple(0, 1)
          : std::make_tuple(1, 0);
    } else {
      if (val(at::abs(gtd_new)) <= (-c2 * val(gtd))) {
        // Wolfe conditions satisfied
        done = true;
      } else if ((val(gtd_new) * (bracket[high_pos] - bracket[low_pos])) >= 0) {
        // old high becomes new low
        bracket[high_pos] = bracket[low_pos];
        bracket_f[high_pos] = bracket_f[low_pos];
        bracket_g[high_pos] = bracket_g[low_pos];
        bracket_gtd[high_pos] = bracket_gtd[low_pos];
      }

      // new point becomes new low
      bracket[low_pos] = t;
      bracket_f[low_pos] = f_new;
      bracket_g[low_pos] = g_new.clone(at::MemoryFormat::Contiguous);
      bracket_gtd[low_pos] = gtd_new;
    }

    // stop if the line-search bracket has become too small
    if ((std::abs(bracket[1] - bracket[0]) * d_norm) < tolerance_change)
      break;
  }

  // return the best point found
  t = bracket[low_pos];
  f_new = bracket_f[low_pos];
  g_new = bracket_g[low_pos];
  return std::make_tuple(f_new, g_new, t, ls_func_evals);
}

Tensor LBFGS::step(LossClosure closure) {
  NoGradGuard no_grad;
  TORCH_CHECK(closure != nullptr, "LBFGS requires a closure function");
  TORCH_INTERNAL_ASSERT(param_groups_.size() == 1);
  auto val = [](const Tensor& t) { return t.item<double>(); };

  auto& group = param_groups_.at(0);
  auto& _params = group.params();
  const auto& options = static_cast<const LBFGSOptions&>(group.options());
  auto lr = options.lr();
  auto max_iter = options.max_iter();
  auto max_eval = options.max_eval();
  auto tolerance_grad = options.tolerance_grad();
  auto tolerance_change = options.tolerance_change();
  auto line_search_fn = options.line_search_fn();
  auto history_size = options.history_size();

  // NOTE: LBFGS has only global state, but we register it as state for
  // the first param, because this helps with casting in load_state_dict
  auto param_state = state_.find(_params.at(0).unsafeGetTensorImpl());
  if (param_state == state_.end()) {
    state_[_params.at(0).unsafeGetTensorImpl()] =
        std::make_unique<LBFGSParamState>();
  }
  auto& state = static_cast<LBFGSParamState&>(
      *state_[_params.at(0).unsafeGetTensorImpl()]);
  // evaluate initial f(x) and df/dx
  Tensor orig_loss;
  {
    torch::AutoGradMode enable_grad(true);
    orig_loss = closure();
  }

  auto loss = val(orig_loss);
  auto current_evals = 1;
  state.func_evals(state.func_evals() + 1);
  auto flat_grad = _gather_flat_grad();
  auto opt_cond = (val(flat_grad.abs().max()) <= tolerance_grad);

  // optimal condition
  if (opt_cond) {
    return orig_loss;
  }

  // tensors cached in state (for tracing)
  auto& d = state.d();
  auto& t = state.t();
  auto& old_dirs = state.old_dirs();
  auto& old_stps = state.old_stps();
  auto& ro = state.ro();
  auto& H_diag = state.H_diag();
  auto& prev_flat_grad = state.prev_flat_grad();
  auto& prev_loss = state.prev_loss();

  int n_iter = 0;

  // optimize for a max of max_iter iterations
  while (n_iter < max_iter) {
    // keep track of the number of iterations
    n_iter += 1;
    state.n_iter(state.n_iter() + 1);

    // compute gradient descent direction
    if (state.n_iter() == 1) {
      d = flat_grad.neg();
      H_diag = torch::tensor(1);
      old_dirs = {};
      old_stps = {};
      ro = {};
    } else {
      // do lbfgs update (update memory)
      auto y = flat_grad.sub(prev_flat_grad);
      auto s = d.mul(t);
      auto ys = y.dot(s); // y*s
      if (val(ys) > 1e-10) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
        // updating memory
        if (static_cast<int64_t>(old_dirs.size()) == history_size) {
          // shift history by one (limited-memory)
          old_dirs.pop_front();
          old_stps.pop_front();
          ro.pop_front();
        }
        // store new direction/step
        old_dirs.emplace_back(y);
        old_stps.emplace_back(s);
        ro.emplace_back(1. / ys);

        // update scale of initial Hessian approximation
        H_diag = ys / y.dot(y); // (y*y)
      }

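      // Two-loop recursion (Nocedal & Wright, Algorithm 7.4): with
      // s_i = x_{i+1} - x_i, y_i = g_{i+1} - g_i and rho_i = 1/(y_i's_i),
      //   backward pass: al_i = rho_i * s_i'q;  q -= al_i * y_i
      //   r = H_diag * q
      //   forward pass:  be_i = rho_i * y_i'r;  r += (al_i - be_i) * s_i
      // yields r ~= -H^{-1} g without ever forming the inverse Hessian.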
      // compute the approximate (L-BFGS) inverse Hessian
      // multiplied by the gradient
      int64_t num_old = static_cast<int64_t>(old_dirs.size());

      if (state.al() == std::nullopt) {
        state.al(std::vector<Tensor>(history_size));
      }
      auto& al = state.al();

      // iteration in L-BFGS loop collapsed to use just one buffer
      auto q = flat_grad.neg();
      for (int64_t i = num_old - 1; i > -1; i--) {
        (*al).at(i) = old_stps.at(i).dot(q) * ro.at(i);
        q.add_(old_dirs.at(i), -val((*al).at(i)));
      }

      // multiply by initial Hessian
      // r/d is the final direction
      auto r = torch::mul(q, H_diag);
      d = r;
      for (const auto i : c10::irange(num_old)) {
        auto be_i = old_dirs.at(i).dot(r) * ro.at(i);
        r.add_(old_stps.at(i), val((*al).at(i) - be_i));
      }
    }

    if (!prev_flat_grad.defined()) {
      prev_flat_grad = flat_grad.clone(at::MemoryFormat::Contiguous);
    } else {
      prev_flat_grad.copy_(flat_grad);
    }
    prev_loss = loss;

    // ############################################################
    // # compute step length
    // ############################################################
    // reset initial guess for step size
    if (state.n_iter() == 1) {
      t = std::min(1., 1. / val(flat_grad.abs().sum())) * lr;
    } else {
      t = lr;
    }

    // directional derivative
    auto gtd = flat_grad.dot(d); // g * d

    // stop if the directional derivative is below tolerance
    if (val(gtd) > -tolerance_change)
      break;

    // optional line search: user function
    auto ls_func_evals = 0;
    if (line_search_fn != std::nullopt) {
      TORCH_CHECK(
          *line_search_fn == "strong_wolfe",
          "only 'strong_wolfe' is supported");
      auto x_init = _clone_param();
      auto obj_func =
          [&](const std::vector<Tensor>& x, double t, const Tensor& d) {
            return _directional_evaluate(closure, x, t, d);
          };
      std::tie(loss, flat_grad, t, ls_func_evals) =
          _strong_wolfe(obj_func, x_init, t, d, loss, flat_grad, gtd);
      _add_grad(t, d);
      opt_cond = (val(flat_grad.abs().max()) <= tolerance_grad);
    } else {
      // no line search, simply move with a fixed step
      _add_grad(t, d);
      if (n_iter != max_iter) {
        // re-evaluate function only if not in last iteration
        // the reason we do this: in a stochastic setting,
        // there is no use in re-evaluating the function here
        {
          torch::AutoGradMode enable_grad(true);
          loss = val(closure());
        }
        flat_grad = _gather_flat_grad();
        opt_cond = val(torch::max(flat_grad.abs())) <= tolerance_grad;
        ls_func_evals = 1;
      }
    }
    // update func eval counts
    current_evals += ls_func_evals;
    state.func_evals(state.func_evals() + ls_func_evals);

    // ############################################################
    // # check conditions
    // ############################################################
    if (n_iter == max_iter)
      break;

    // `max_eval` is an optional; guard against it being unset before
    // dereferencing
    if (max_eval != std::nullopt && current_evals >= *max_eval)
      break;

    // optimal condition
    if (opt_cond)
      break;

    // lack of progress
    if (val(d.mul(t).abs().max()) <= tolerance_change)
      break;

    if (std::abs(loss - prev_loss) < tolerance_change)
      break;
  }

  return orig_loss;
}

void LBFGS::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void LBFGS::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized LBFGS optimizer is still using the old serialization format. "
        "The func_evals and n_iter value in state will be set to 0, ro will be set to an empty deque "
        "and al will be set to std::nullopt because the old LBFGS optimizer didn't save these values. "
        "You should re-save your LBFGS optimizer to use the new serialization format.");
    Tensor d, t, H_diag, prev_flat_grad, prev_loss;
    std::deque<Tensor> old_dirs, old_stps;
    archive("d", d, /*is_buffer=*/true);
    archive("t", t, /*is_buffer=*/true);
    archive("H_diag", H_diag, /*is_buffer=*/true);
    archive("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
    archive("prev_loss", prev_loss, /*is_buffer=*/true);
    torch::optim::serialize(archive, "old_dirs", old_dirs);
    torch::optim::serialize(archive, "old_stps", old_stps);

    // NOTE: LBFGS has only global state, but we register it as state for
    // the first param, because this helps with casting in load_state_dict
    auto state = std::make_unique<LBFGSParamState>();
    state->d(d);
    state->t(t.item<double>());
    state->H_diag(H_diag);
    state->prev_flat_grad(prev_flat_grad);
    state->prev_loss(prev_loss.item<double>());
    state->old_dirs(old_dirs);
    state->old_stps(old_stps);
    state_[param_groups_.at(0).params().at(0).unsafeGetTensorImpl()] =
        std::move(state);
  }
}

} // namespace optim
} // namespace torch