mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35164 As title Test Plan: CI Reviewed By: jianyuh Differential Revision: D20581853 fbshipit-source-id: 393ddd9487cd965c465eaa49e1509863618a6048
42 lines
766 B
C++
42 lines
766 B
C++
#include "caffe2/sgd/math_lp.h"

#include <vector>
|
|
|
|
namespace caffe2 {
|
|
|
|
namespace internal {
|
|
|
|
template <>
|
|
void dot<float, float, float>(
|
|
const int N,
|
|
const float* x,
|
|
const float* y,
|
|
float* z,
|
|
CPUContext* ctx) {
|
|
math::Dot<float, CPUContext>(N, x, y, z, ctx);
|
|
}
|
|
|
|
template <>
|
|
void dot<float, at::Half, float>(
|
|
const int N,
|
|
const float* x,
|
|
const at::Half* y,
|
|
float* z,
|
|
CPUContext* ctx) {
|
|
#ifdef _MSC_VER
|
|
std::vector<float> tmp_y_vec(N);
|
|
float* tmp_y = tmp_y_vec.data();
|
|
#else
|
|
float tmp_y[N];
|
|
#endif
|
|
for (int i = 0; i < N; i++) {
|
|
#ifdef __F16C__
|
|
tmp_y[i] = _cvtss_sh(y[i], 0); // TODO: vectorize
|
|
#else
|
|
tmp_y[i] = y[i];
|
|
#endif
|
|
}
|
|
math::Dot<float, CPUContext>(N, x, tmp_y, z, ctx);
|
|
}
|
|
|
|
} // namespace internal
|
|
} // namespace caffe2
|