Fix compilation error (#17860)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17860

att

Reviewed By: bddppq

Differential Revision: D14402751

fbshipit-source-id: 2d53b230dfd775372addeab1d3eaf0b9552fae9f
This commit is contained in:
Yinghai Lu 2019-03-11 10:20:11 -07:00 committed by Facebook Github Bot
parent b3c9090736
commit abd39d5a88
2 changed files with 22 additions and 26 deletions

View File

@ -1,33 +1,9 @@
#include "caffe2/operators/normalize_op.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
template <typename T, class Context>
void NormalizeOp<T, Context>::DoNormalize(
const T* xData,
T* yData,
const int m,
const int n,
const int sf) {
using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
using StridedVec =
Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
using ConstStridedVec =
Eigen::Map<const Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
for (int i = 0; i < n; ++i) {
auto base = (i / sf) * sf * m + (i % sf);
ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf));
auto norm = xVec.template lpNorm<2>();
norm = std::max(norm, kEps_);
StridedVec yVec(yData + base, 1, m, InnerStride(sf));
yVec = xVec / norm;
}
};
template <typename T, class Context>
void NormalizeGradientOp<T, Context>::DoNormalize(
const T* xData,

View File

@ -3,6 +3,7 @@
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
#define KEPS 1e-12f
@ -35,8 +36,27 @@ class NormalizeOp final : public Operator<Context> {
private:
const T kEps_ = KEPS;
void
DoNormalize(const T* xData, T* yData, const int m, const int n, const int sf);
void DoNormalize(
const T* xData,
T* yData,
const int m,
const int n,
const int sf) {
using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
using StridedVec =
Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
using ConstStridedVec =
Eigen::Map<const Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
for (int i = 0; i < n; ++i) {
auto base = (i / sf) * sf * m + (i % sf);
ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf));
auto norm = xVec.template lpNorm<2>();
norm = std::max(norm, kEps_);
StridedVec yVec(yData + base, 1, m, InnerStride(sf));
yVec = xVec / norm;
}
}
};
template <typename T, class Context>