Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24357

SparseNormalize does not need to know the gradient values for the lookup table, only the indices of the embeddings that need to be updated. By removing this input, we allow SparseNormalize to be used alongside SparseAdagradFusion.

Differential Revision: D16809919

fbshipit-source-id: cc19692ba4dea8854663ae1ed8cf9365e90c99bc
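For illustration, here is a minimal sketch of an operator definition with the reduced input list described above, built directly from the caffe2 protos. The blob names ("param", "indices") and the helper name MakeSparseNormalizeDef are assumptions for this example, not part of the change itself:

#include "caffe2/proto/caffe2_pb.h"

// Hypothetical helper: SparseNormalize now reads only the parameter blob
// and the sparse indices; the gradient input is no longer required, so the
// same indices blob can also feed a fused optimizer such as
// SparseAdagradFusion.
caffe2::OperatorDef MakeSparseNormalizeDef() {
  caffe2::OperatorDef def;
  def.set_type("SparseNormalize");
  def.add_input("param");    // PARAM: the embedding table, updated in place
  def.add_input("indices");  // INDICES: rows touched by the current step
  def.add_output("param");   // OUTPUT_PARAM
  auto* norm_arg = def.add_arg();
  norm_arg->set_name("norm");
  norm_arg->set_f(1.0f);
  return def;
}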
#pragma once

#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// Applies L2-norm regularization in place to the rows of PARAM selected by
// INDICES. With use_max_norm = true (the default), a row is rescaled only if
// its L2 norm exceeds `norm`; with use_max_norm = false, every selected row
// is rescaled to have an L2 norm equal to `norm`.
template <typename T, class Context>
class CAFFE2_API SparseNormalizeOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit SparseNormalizeOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        use_max_norm_(
            this->template GetSingleArgument<bool>("use_max_norm", true)),
        norm_(this->template GetSingleArgument<float>("norm", 1.0)) {
    CAFFE_ENFORCE_GE(norm_, 0, "norm should be bigger than 0");
  }

  bool RunOnDevice() override;

  // Dispatched on the integral type of INDICES (e.g. int32_t or int64_t).
  template <typename SIndex>
  bool DoRunWithType();

 protected:
  bool use_max_norm_;
  float norm_;
  INPUT_TAGS(PARAM, INDICES);
  OUTPUT_TAGS(OUTPUT_PARAM);
};

} // namespace caffe2
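The header only declares RunOnDevice and DoRunWithType; the definitions live in the corresponding source files. As a rough, standalone sketch of the per-row logic implied by the arguments above (the function name, signature, and epsilon are assumptions for this illustration, not the actual caffe2 implementation):

#include <cmath>
#include <cstdint>
#include <vector>

// Sketch of SparseNormalize's per-row update: `param` is a row-major
// [num_rows x block_size] embedding table and `indices` lists the rows to
// normalize. With use_max_norm = true, a row is rescaled only if its L2
// norm exceeds `norm`; otherwise every indexed row is scaled to norm `norm`.
void SparseNormalizeRows(
    std::vector<float>& param,
    const std::vector<int64_t>& indices,
    int64_t block_size,
    float norm,
    bool use_max_norm) {
  const float kEps = 1e-12f;
  for (int64_t idx : indices) {
    float* row = param.data() + idx * block_size;
    float sum_sq = 0.0f;
    for (int64_t j = 0; j < block_size; ++j) {
      sum_sq += row[j] * row[j];
    }
    const float row_norm = std::sqrt(sum_sq);
    // Max-norm: leave rows that already satisfy the constraint untouched.
    if (use_max_norm && row_norm <= norm) {
      continue;
    }
    const float scale = norm / (row_norm + kEps);
    for (int64_t j = 0; j < block_size; ++j) {
      row[j] *= scale;
    }
  }
}

With use_max_norm = true this is effectively a projection of the touched rows onto the L2 ball of radius norm, which only requires knowing which rows were updated; that is why the operator can drop the gradient input, as described in the summary above.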