mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: * Likely need to test this so bad formatting can't be added in the future, but cleaning all operators so we at least have good examples. * Formatting between our internal Facebook operator catalog and external caffe2.ai catalog are still slightly different. We'll work on this. Closes https://github.com/caffe2/caffe2/pull/1846 Reviewed By: pjh5 Differential Revision: D6848570 Pulled By: orionr fbshipit-source-id: b9bc0bfccb243d0440bd7b2406858cad8dc37e92
27 lines
802 B
C++
27 lines
802 B
C++
#include "caffe2/operators/negate_gradient_op.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
// Register NegateGradientOp<CPUContext> as the CPU implementation of the
// "NegateGradient" operator.
REGISTER_CPU_OPERATOR(NegateGradient, NegateGradientOp<CPUContext>);
|
|
// Schema for NegateGradient: exactly one input and one output, and the output
// is allowed to be computed in place over input 0. The doc string below is
// user-facing operator documentation, so the operator name must be spelled
// correctly.
OPERATOR_SCHEMA(NegateGradient)
    .NumInputs(1)
    .NumOutputs(1)
    .AllowInplace({{0, 0}})
    .SetDoc(R"DOC(
NegateGradient operator in forward pass simply copies input to the
output, and in backward pass, flips the sign of the output gradient.
)DOC");
|
|
|
|
struct GetNegateGradientGradient : public GradientMakerBase {
|
|
using GradientMakerBase::GradientMakerBase;
|
|
std::vector<OperatorDef> GetGradientDefs() override {
|
|
CAFFE_ENFORCE_EQ(def_.input_size(), 1);
|
|
return SingleGradientDef(
|
|
"Negative", "", vector<string>{GO(0)}, vector<string>{GI(0)});
|
|
}
|
|
};
|
|
|
|
// Register GetNegateGradientGradient as the gradient maker for the
// "NegateGradient" operator.
REGISTER_GRADIENT(NegateGradient, GetNegateGradientGradient);
|
|
|
|
} // namespace caffe2
|