mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This reverts commit 716b3b893d.
Reverted https://github.com/pytorch/pytorch/pull/103725 on behalf of https://github.com/osalpekar due to Broke caffe2 builds due. More info at [D46920675](https://www.internalfb.com/diff/D46920675) ([comment](https://github.com/pytorch/pytorch/pull/103725#issuecomment-1603129273))
185 lines
3.5 KiB
C++
185 lines
3.5 KiB
C++
#include "caffe2/perfkernels/adagrad.h"
|
|
|
|
#include <cmath>
|
|
|
|
#include "caffe2/perfkernels/common.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
void adagrad_update__base(
|
|
int N,
|
|
const float* w,
|
|
const float* g,
|
|
const float* h,
|
|
float* nw,
|
|
float* nh,
|
|
float epsilon,
|
|
float decay,
|
|
const float lr,
|
|
const float weight_decay = 0.f) {
|
|
internal::adagrad_update_base_inlined(
|
|
N, w, g, h, nw, nh, decay, epsilon, lr, weight_decay);
|
|
}
|
|
|
|
void adagrad_update_prefetch__base(
|
|
int N,
|
|
const float* w,
|
|
const float* /* w_n */, // prefetch ptr
|
|
|
|
const float* g,
|
|
|
|
const float* h,
|
|
const float* /* h_n */, // prefetch ptr
|
|
|
|
float* nw,
|
|
float* /* nw_n */, // prefetch ptr
|
|
|
|
float* nh,
|
|
float* /* nh_n */, // prefetch ptr
|
|
|
|
float epsilon,
|
|
float lr,
|
|
float weight_decay = 0.f) {
|
|
adagrad_update__base(N, w, g, h, nw, nh, epsilon, 1.0f, lr, weight_decay);
|
|
}
|
|
|
|
void adagrad_fp16_update_prefetch__base(
|
|
int N,
|
|
const at::Half* w,
|
|
const at::Half* /* w_n */, // prefetch ptr
|
|
const float* g,
|
|
const at::Half* h,
|
|
const at::Half* /* h_n */, // prefetch ptr
|
|
at::Half* nw,
|
|
at::Half* /* nw_n */, // prefetch ptr
|
|
at::Half* nh,
|
|
at::Half* /* nh_n */, // prefetch ptr
|
|
float epsilon,
|
|
float lr,
|
|
float weight_decay = 0.f) {
|
|
internal::adagrad_update_base_inlined(
|
|
N, w, g, h, nw, nh, 1.0f, epsilon, lr, weight_decay);
|
|
}
|
|
|
|
// version without prefetching
|
|
decltype(adagrad_update__base) adagrad_update__avx2_fma;
|
|
void adagrad_update(
|
|
int N,
|
|
const float* w,
|
|
const float* g,
|
|
const float* h,
|
|
float* nw,
|
|
float* nh,
|
|
float epsilon,
|
|
float decay,
|
|
float lr,
|
|
float weight_decay) {
|
|
AVX2_FMA_DO(
|
|
adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay);
|
|
BASE_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay);
|
|
}
|
|
|
|
decltype(adagrad_update_prefetch__base) adagrad_update_prefetch__avx2_fma;
|
|
void adagrad_update_prefetch(
|
|
int N,
|
|
const float* w,
|
|
const float* w_n, // prefetch ptr
|
|
|
|
const float* g,
|
|
|
|
const float* h,
|
|
const float* h_n, // prefetch ptr
|
|
|
|
float* nw,
|
|
float* nw_n, // prefetch ptr
|
|
|
|
float* nh,
|
|
float* nh_n, // prefetch ptr
|
|
|
|
float epsilon,
|
|
float lr,
|
|
float weight_decay) {
|
|
AVX2_FMA_DO(
|
|
adagrad_update_prefetch,
|
|
N,
|
|
w,
|
|
w_n,
|
|
g,
|
|
h,
|
|
h_n,
|
|
nw,
|
|
nw_n,
|
|
nh,
|
|
nh_n,
|
|
epsilon,
|
|
lr,
|
|
weight_decay);
|
|
BASE_DO(
|
|
adagrad_update_prefetch,
|
|
N,
|
|
w,
|
|
w_n,
|
|
g,
|
|
h,
|
|
h_n,
|
|
nw,
|
|
nw_n,
|
|
nh,
|
|
nh_n,
|
|
epsilon,
|
|
lr,
|
|
weight_decay);
|
|
}
|
|
|
|
// Version with prefetching for embeddings and
|
|
// momentum using fp16
|
|
decltype(adagrad_fp16_update_prefetch__base)
|
|
adagrad_fp16_update_prefetch__avx2_fma;
|
|
void adagrad_fp16_update_prefetch(
|
|
int N,
|
|
const at::Half* w,
|
|
const at::Half* w_n, // prefetch ptr
|
|
const float* g,
|
|
const at::Half* h,
|
|
const at::Half* h_n, // prefetch ptr
|
|
at::Half* nw,
|
|
at::Half* nw_n, // prefetch ptr
|
|
at::Half* nh,
|
|
at::Half* nh_n, // prefetch ptr
|
|
float epsilon,
|
|
float lr,
|
|
float weight_decay) {
|
|
AVX2_FMA_DO(
|
|
adagrad_fp16_update_prefetch,
|
|
N,
|
|
w,
|
|
w_n,
|
|
g,
|
|
h,
|
|
h_n,
|
|
nw,
|
|
nw_n,
|
|
nh,
|
|
nh_n,
|
|
epsilon,
|
|
lr,
|
|
weight_decay);
|
|
BASE_DO(
|
|
adagrad_fp16_update_prefetch,
|
|
N,
|
|
w,
|
|
w_n,
|
|
g,
|
|
h,
|
|
h_n,
|
|
nw,
|
|
nw_n,
|
|
nh,
|
|
nh_n,
|
|
epsilon,
|
|
lr,
|
|
weight_decay);
|
|
}
|
|
|
|
} // namespace caffe2
|