pytorch/caffe2/operators/sparse_normalize_op.h
Frank Jiang d7c6debc14 Remove gradient value as input from SparseNormalize op (#24357)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24357

SparseNormalize does not need to know the gradient values for the lookup table, only the indices of the embeddings that need to be updated. By removing this input, we allow SparseNormalize to be used alongside SparseAdagradFusion.

Differential Revision: D16809919

fbshipit-source-id: cc19692ba4dea8854663ae1ed8cf9365e90c99bc
2019-08-19 14:47:09 -07:00

34 lines
835 B
C++

#pragma once
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
/**
 * Operator declaration for SparseNormalize.
 *
 * Inputs (see INPUT_TAGS):  PARAM, INDICES.
 * Outputs (see OUTPUT_TAGS): OUTPUT_PARAM.
 *
 * The operator takes only the parameter blob and the indices of the rows to
 * touch; it does not take a gradient input.
 *
 * Operator arguments read at construction time:
 *   - "use_max_norm" (bool,  default true)
 *   - "norm"         (float, default 1.0); must be >= 0, otherwise the
 *                    constructor fails via CAFFE_ENFORCE_GE.
 *
 * NOTE(review): the exact normalization semantics of `use_max_norm` and
 * `norm` live in the out-of-line definitions of RunOnDevice / DoRunWithType,
 * which are not part of this header -- confirm there before relying on them.
 *
 * @tparam T        Element type of the parameter data.
 * @tparam Context  Device context the operator runs on.
 */
template <typename T, class Context>
class CAFFE2_API SparseNormalizeOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /// Forwards all constructor arguments to Operator<Context> and then reads
  /// and validates the operator arguments.
  template <class... Args>
  explicit SparseNormalizeOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        use_max_norm_(
            this->template GetSingleArgument<bool>("use_max_norm", true)),
        norm_(this->template GetSingleArgument<float>("norm", 1.0f)) {
    // The check accepts 0, so the message must say so; the previous text
    // ("bigger than 0") contradicted the >= comparison actually enforced.
    CAFFE_ENFORCE_GE(norm_, 0, "norm should be greater than or equal to 0");
  }

  /// Operator entry point; defined out of line.
  bool RunOnDevice() override;

  /// Typed implementation, parameterized on the index element type; defined
  /// out of line.
  template <typename SIndex>
  bool DoRunWithType();

 protected:
  bool use_max_norm_;  ///< Value of the "use_max_norm" argument.
  float norm_;         ///< Value of the "norm" argument (validated >= 0).

  INPUT_TAGS(PARAM, INDICES);
  OUTPUT_TAGS(OUTPUT_PARAM);
};
} // namespace caffe2