Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37705
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37372

Posted note: [Regularizing SparseNN Against Over-fitting](https://fb.workplace.com/notes/taiqing-wang/regularizing-sparsenn-against-over-fitting/220306075902708/)

**Problem formulation**

L(w) = J(w) + lambda/2 * ||w||^2

J(w) is the empirical loss and ||w||^2 is the squared L2 norm of the parameters, i.e. the L2 regularizer. Differentiating w.r.t. w_i gives

dL(w)/dw_i = dJ(w)/dw_i + lambda * w_i

so the L2 regularizer is implemented by adding lambda * w_i to the gradient of J(w) w.r.t. w_i. lambda is exposed as weight_decay in this implementation.

**Code changes**

* In the initialization method of AdagradOptimizer, a new input argument, weight_decay, is added.
* In the _run function of AdagradOptimizer, weight decay is skipped for 1d bias vectors.
* In the parameter update functions of Adagrad, the gradient is incremented by weight_decay * w_i. The default value of weight_decay is zero.

Test Plan:

`buck build caffe2/caffe2/fb/dper/layer_models/tests/split_1:sparse_nn_test_weight_decay`

`./buck-out/gen/caffe2/caffe2/fb/dper/layer_models/tests/split_1/sparse_nn_test_weight_decay#binary.par`

Reviewed By: jspark1105

Differential Revision: D21258652

fbshipit-source-id: d2366ddcd736a03205a2d16f914703b16d9fce8f
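For intuition, the scalar form of this update, which the AVX2 kernels below vectorize, can be sketched as follows. This is a minimal illustration, not the shipped kernel; the function name and signature are ours:

```
#include <cmath>

// Sketch of one Adagrad step with L2 regularization (weight_decay = lambda).
// Mirrors the math above; the kernels below compute the same thing
// eight floats at a time.
void adagrad_step_scalar(
    int N,
    const float* w, // parameters
    const float* g, // gradient of J(w)
    const float* h, // running (decayed) sum of squared gradients
    float* nw, // updated parameters
    float* nh, // updated accumulator
    float epsilon,
    float decay,
    float lr,
    float weight_decay) {
  for (int i = 0; i < N; ++i) {
    // dL/dw_i = dJ/dw_i + lambda * w_i
    float gi = g[i] + weight_decay * w[i];
    nh[i] = decay * h[i] + gi * gi;
    // The step is added, matching the kernels below; the sign convention
    // of lr is up to the caller.
    nw[i] = w[i] + lr * gi / (std::sqrt(nh[i]) + epsilon);
  }
}
```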
126 lines
3.8 KiB
C++
#include "caffe2/perfkernels/adagrad.h"

#include "caffe2/perfkernels/cvtsh_ss_bugfix.h"

#include <emmintrin.h>
#include <immintrin.h>

#include <cmath> // std::fma, std::sqrt in the scalar tail loops

namespace caffe2 {

// version without prefetching
void adagrad_update__avx2_fma(
    int N,
    const float* w,
    const float* g,
    const float* h,
    float* nw,
    float* nh,
    float epsilon,
    float decay,
    float lr,
    float weight_decay = 0.f) {
  constexpr size_t kSize = 8;
  auto i = 0;
  // Main loop: process 8 floats per iteration with AVX2 + FMA.
  for (; i + kSize <= N; i += kSize) {
    __m256 gi = _mm256_loadu_ps(g + i);
    __m256 hi = _mm256_loadu_ps(h + i);
    __m256 wi = _mm256_loadu_ps(w + i);
    // L2 regularization: gi += weight_decay * wi, fused into one FMA.
    gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi);

    // Decayed accumulator: nh_i = decay * h_i + gi^2.
    __m256 nhi = _mm256_add_ps(
        _mm256_mul_ps(_mm256_set1_ps(decay), hi), _mm256_mul_ps(gi, gi));
    _mm256_storeu_ps(nh + i, nhi);
    // Adaptive step: nw_i = w_i + lr * gi / (sqrt(nh_i) + epsilon).
    __m256 vtmp = _mm256_div_ps(
        _mm256_mul_ps(_mm256_set1_ps(lr), gi),
        _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
    _mm256_storeu_ps(nw + i, _mm256_add_ps(wi, vtmp));
  }

  // Scalar tail for the remaining N % 8 elements.
  for (; i < N; ++i) {
    float gi = std::fma(weight_decay, w[i], g[i]);
    float hi = nh[i] = decay * h[i] + gi * gi;
    nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
  }
}

void adagrad_update_prefetch__avx2_fma(
    int N,
    const float* w,
    const float* w_n, // prefetch ptr
    const float* g,
    const float* h,
    const float* h_n, // prefetch ptr
    float* nw,
    float* nw_n, // prefetch ptr
    float* nh,
    float* nh_n, // prefetch ptr
    float epsilon,
    float lr,
    float weight_decay = 0.f) {
  // Delegates to the shared inlined implementation; the *_n pointers let the
  // kernel prefetch the next rows while updating the current ones.
  internal::adagrad_update_prefetch_inlined(
      N, w, w_n, g, h, h_n, nw, nw_n, nh, nh_n, epsilon, lr, weight_decay);
}

// Compute adagrad sparse, assumes embedding and momentum are at::Half
void adagrad_fp16_update_prefetch__avx2_fma(
    int N,
    const at::Half* w,
    const at::Half* w_n, // prefetch ptr
    const float* g,
    const at::Half* h,
    const at::Half* h_n, // prefetch ptr
    at::Half* nw,
    at::Half* nw_n, // prefetch ptr
    at::Half* nh,
    at::Half* nh_n, // prefetch ptr
    float epsilon,
    float lr,
    float weight_decay = 0.f) {
  constexpr int kSize = 8;
  auto i = 0;
  for (; i + kSize <= N; i += kSize) {
    // Prefetch the next rows so they are cached by the time they are needed.
    _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
    _mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0);
    _mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0);
    _mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0);

    // only convert momentum and embedding, gradient is fp32
    __m256 gi = _mm256_loadu_ps(g + i);
    __m128i hhi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(h + i));
    __m256 hi = _mm256_cvtph_ps(hhi);
    __m128i whi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(w + i));
    __m256 wi = _mm256_cvtph_ps(whi);
    // L2 regularization: gi += weight_decay * wi, fused into one FMA.
    gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi);

    // Accumulator (no decay term here): nh_i = h_i + gi^2, stored as fp16.
    __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
    __m128i nhhi = _mm256_cvtps_ph(nhi, 0);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(nh + i), nhhi);

    // Adaptive step: nw_i = w_i + lr * gi / (sqrt(nh_i) + epsilon), fp16 out.
    __m256 vtmp = _mm256_div_ps(
        _mm256_mul_ps(_mm256_set1_ps(lr), gi),
        _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
    __m256 nwi = _mm256_add_ps(wi, vtmp);
    __m128i nhwi = _mm256_cvtps_ph(nwi, 0);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(nw + i), nhwi);
  }

  // Scalar tail: convert fp16 <-> fp32 one element at a time.
  for (; i < N; ++i) {
    float gi = std::fma(
        weight_decay,
        _cvtsh_ss(reinterpret_cast<const unsigned short*>(w)[i]),
        g[i]);
    float nhi =
        _cvtsh_ss(reinterpret_cast<const unsigned short*>(h)[i]) + gi * gi;
    reinterpret_cast<unsigned short*>(nh)[i] = _cvtss_sh(nhi, 0);
    float nwi = _cvtsh_ss(reinterpret_cast<const unsigned short*>(w)[i]) +
        lr * gi / (std::sqrt(nhi) + epsilon);
    reinterpret_cast<unsigned short*>(nw)[i] = _cvtss_sh(nwi, 0);
  }
}

} // namespace caffe2
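
A hedged call-site sketch for the non-prefetch kernel (illustrative only; these `__avx2_fma` variants are normally reached through the perfkernels dispatch in adagrad.h rather than called directly, and the example function name is ours):

// Example (ours): update 10 weights so both the vectorized loop and the
// scalar tail are exercised.
#include <vector>

void adagrad_update_example() {
  constexpr int N = 10; // not a multiple of 8 -> tail loop runs for i = 8, 9
  std::vector<float> w(N, 1.f), g(N, 0.1f), h(N, 0.f), nw(N), nh(N);
  caffe2::adagrad_update__avx2_fma(
      N,
      w.data(),
      g.data(),
      h.data(),
      nw.data(),
      nh.data(),
      /*epsilon=*/1e-5f,
      /*decay=*/1.f, // 1.0 keeps the classic undecayed Adagrad accumulator
      /*lr=*/-0.01f, // the kernel adds lr * g, so descent needs a negative lr
      /*weight_decay=*/0.01f);
}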