pytorch/caffe2/perfkernels/adagrad.cc
Taiqing Wang 8cb1f2f9dc implement L2 regularization for Adagrad in caffe2 and dper (#37705)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37705

Pull Request resolved: https://github.com/pytorch/pytorch/pull/37372

Posted note: [Regularizing SparseNN Against Over-fitting](https://fb.workplace.com/notes/taiqing-wang/regularizing-sparsenn-against-over-fitting/220306075902708/)

**Problem formulation**

L(w) = J(w) + lambda/2 * ||w||^2
J(w) is the empirical loss, and ||w||^2 is the squared L2 norm of the parameters, a.k.a. L2 regularizer.

dL(w)/ dw_i = dJ(w)/dw_i + lambda w_i
dL(w)/ dw_i is the gradient of L(w) w.r.t. w_i.

To implement the L2 regularizer, the gradient of J(w) w.r.t. w_i is incremented by lambda * w_i. lambda is referred to as weight_decay in this implementation.

**Code changes**
* In the initialization method of AdagradOptimizer, a new input argument, weight_decay, is added.
* In the _run function of AdagradOptimizer, the weight decay will be skipped for 1d bias vectors.
* In the parameter update functions of Adagrad, the gradient is updated by weight_decay * w_i. The default value for weight_decay is zero.

Test Plan:
`
buck build caffe2/caffe2/fb/dper/layer_models/tests/split_1:sparse_nn_test_weight_decay
`

`
./buck-out/gen/caffe2/caffe2/fb/dper/layer_models/tests/split_1/sparse_nn_test_weight_decay#binary.par
`

Reviewed By: jspark1105

Differential Revision: D21258652

fbshipit-source-id: d2366ddcd736a03205a2d16f914703b16d9fce8f
2020-05-03 10:42:49 -07:00

185 lines
3.5 KiB
C++

#include "caffe2/perfkernels/adagrad.h"
#include <cmath>
#include "caffe2/perfkernels/common.h"
namespace caffe2 {
// Scalar (non-vectorized) AdaGrad step over N contiguous elements.
//
// Thin forwarder to the shared inlined implementation in adagrad.h.
// Careful: this signature takes (epsilon, decay) but the inlined helper
// takes (decay, epsilon), so the two middle scalars are swapped below.
//
// w/g/h:  current weights, gradients, and accumulated squared-gradient moments
// nw/nh:  outputs — updated weights and updated moments (may alias w/h)
// decay:  multiplicative decay applied to the history h
// lr:     learning rate; weight_decay: L2 coefficient added as wd * w_i
void adagrad_update__base(
    int N,
    const float* w,
    const float* g,
    const float* h,
    float* nw,
    float* nh,
    float epsilon,
    float decay,
    const float lr,
    const float weight_decay = 0.f) {
  // Note the (decay, epsilon) order expected by the helper.
  internal::adagrad_update_base_inlined(
      N, w, g, h, nw, nh, decay, epsilon, lr, weight_decay);
}
void adagrad_update_prefetch__base(
int N,
const float* w,
const float* /* w_n */, // prefetch ptr
const float* g,
const float* h,
const float* /* h_n */, // prefetch ptr
float* nw,
float* /* nw_n */, // prefetch ptr
float* nh,
float* /* nh_n */, // prefetch ptr
float epsilon,
float lr,
float weight_decay = 0.f) {
adagrad_update__base(N, w, g, h, nw, nh, epsilon, 1.0f, lr, weight_decay);
}
// Scalar AdaGrad step for half-precision storage: weights and moment
// history are at::Half while gradients arrive as fp32.
//
// Like the fp32 prefetch variant, the *_n pointers are unused here — the
// base kernel performs no software prefetching and keeps them only for
// signature parity with the AVX2 version. Decay is pinned to 1.0.
void adagrad_fp16_update_prefetch__base(
    int N,
    const at::Half* w,
    const at::Half* /* w_n */, // prefetch ptr
    const float* g,
    const at::Half* h,
    const at::Half* /* h_n */, // prefetch ptr
    at::Half* nw,
    at::Half* /* nw_n */, // prefetch ptr
    at::Half* nh,
    at::Half* /* nh_n */, // prefetch ptr
    float epsilon,
    float lr,
    float weight_decay = 0.f) {
  // The inlined helper is templated over the element type, so it accepts
  // the at::Half pointers directly.
  internal::adagrad_update_base_inlined(
      N, w, g, h, nw, nh, /*decay=*/1.0f, epsilon, lr, weight_decay);
}
// version without prefetching
// Forward declaration of the AVX2+FMA specialization; decltype keeps its
// signature byte-for-byte in sync with the base implementation. The body is
// presumably compiled in a separate TU built with AVX2/FMA flags — confirm
// against the perfkernels build rules.
decltype(adagrad_update__base) adagrad_update__avx2_fma;
// Public entry point: runtime CPU dispatch between the AVX2+FMA kernel and
// the portable scalar kernel. NOTE(review): AVX2_FMA_DO presumably invokes
// the *__avx2_fma variant and returns when the CPU supports AVX2+FMA,
// falling through to BASE_DO otherwise — verify in perfkernels/common.h.
void adagrad_update(
int N,
const float* w,
const float* g,
const float* h,
float* nw,
float* nh,
float epsilon,
float decay,
float lr,
float weight_decay) {
AVX2_FMA_DO(
adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay);
BASE_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay);
}
// Forward declaration of the AVX2+FMA specialization; decltype guarantees
// the dispatch target's signature matches the base implementation exactly.
decltype(adagrad_update_prefetch__base) adagrad_update_prefetch__avx2_fma;
// Public entry point for the prefetching fp32 AdaGrad kernel. The *_n
// pointers name the next rows to software-prefetch; the scalar base kernel
// ignores them (see adagrad_update_prefetch__base above).
void adagrad_update_prefetch(
int N,
const float* w,
const float* w_n, // prefetch ptr
const float* g,
const float* h,
const float* h_n, // prefetch ptr
float* nw,
float* nw_n, // prefetch ptr
float* nh,
float* nh_n, // prefetch ptr
float epsilon,
float lr,
float weight_decay) {
// NOTE(review): AVX2_FMA_DO presumably dispatches-and-returns on capable
// CPUs, so BASE_DO is the fallback — confirm in perfkernels/common.h.
AVX2_FMA_DO(
adagrad_update_prefetch,
N,
w,
w_n,
g,
h,
h_n,
nw,
nw_n,
nh,
nh_n,
epsilon,
lr,
weight_decay);
BASE_DO(
adagrad_update_prefetch,
N,
w,
w_n,
g,
h,
h_n,
nw,
nw_n,
nh,
nh_n,
epsilon,
lr,
weight_decay);
}
// Version with prefetching for embeddings and
// momentum using fp16
// Forward declaration of the AVX2+FMA specialization; decltype keeps the
// dispatch target's signature in sync with the base implementation.
decltype(
adagrad_fp16_update_prefetch__base) adagrad_fp16_update_prefetch__avx2_fma;
// Public entry point for the half-precision AdaGrad kernel: weights and
// moment history are stored as at::Half, gradients as fp32. The *_n
// pointers are prefetch hints; the scalar base kernel ignores them.
void adagrad_fp16_update_prefetch(
int N,
const at::Half* w,
const at::Half* w_n, // prefetch ptr
const float* g,
const at::Half* h,
const at::Half* h_n, // prefetch ptr
at::Half* nw,
at::Half* nw_n, // prefetch ptr
at::Half* nh,
at::Half* nh_n, // prefetch ptr
float epsilon,
float lr,
float weight_decay) {
// NOTE(review): AVX2_FMA_DO presumably dispatches-and-returns on capable
// CPUs, so BASE_DO is the fallback — confirm in perfkernels/common.h.
AVX2_FMA_DO(
adagrad_fp16_update_prefetch,
N,
w,
w_n,
g,
h,
h_n,
nw,
nw_n,
nh,
nh_n,
epsilon,
lr,
weight_decay);
BASE_DO(
adagrad_fp16_update_prefetch,
N,
w,
w_n,
g,
h,
h_n,
nw,
nw_n,
nh,
nh_n,
epsilon,
lr,
weight_decay);
}
} // namespace caffe2