#include "caffe2/operators/minmax_ops.h"

#include <climits>
#include <cstdint>
#include <string>
#include <vector>

#include "caffe2/utils/eigen_utils.h"

namespace caffe2 {

// Shared backward pass for the element-wise MaxGradient / MinGradient ops.
//
// Inputs:  0 = Y (forward output), 1 = dY (gradient w.r.t. Y),
//          2.. = X_0, X_1, ... (forward inputs).
// Outputs: dX_0, dX_1, ... -- output i is the gradient w.r.t. input i + 2 and
//          is allocated with the same shape as that input.
//
// For every position k: dX_i[k] = dY[k] if X_i[k] == Y[k], otherwise 0.
// The equality test is exact by design. Y is assumed to hold a value copied
// from one of the X_i, so the selected input compares equal to it.
// NOTE(review): when several inputs tie at a position, each of them receives
// the full dY[k]; the gradient is not split between them.
template <typename T, class Context>
bool SelectGradientOpBase<T, Context>::RunOnDevice() {
  const auto& Y = Input(0);
  const auto& dY = Input(1);
  // Keep the element count 64-bit; narrowing to int would overflow for
  // tensors with more than INT_MAX elements.
  const int64_t N = Y.numel();
  CAFFE_ENFORCE_EQ(
      dY.numel(), N, "dY must have the same number of elements as Y.");
  ConstEigenVectorArrayMap<T> Y_arr(Y.template data<T>(), N);
  ConstEigenVectorArrayMap<T> dY_arr(dY.template data<T>(), N);
  for (int i = 0; i < OutputSize(); ++i) {
    const auto& Xi = Input(i + 2);
    // X_i and dX_i are both viewed as flat buffers of N elements. A size
    // mismatch would otherwise read past the end of X_i, or leave part of
    // dX_i uninitialized.
    CAFFE_ENFORCE_EQ(
        Xi.numel(),
        N,
        "Input #",
        i + 2,
        " must have the same number of elements as Y.");
    auto* dXi = Output(i, Xi.sizes(), at::dtype<T>());
    ConstEigenVectorArrayMap<T> Xi_arr(Xi.template data<T>(), N);
    EigenVectorArrayMap<T> dXi_arr(dXi->template mutable_data<T>(), N);
    // Mask is 1 where X_i equals Y and 0 elsewhere; scaling it by dY routes
    // the upstream gradient only to the selected positions.
    dXi_arr = (Xi_arr == Y_arr).template cast<T>() * dY_arr;
  }
  return true;
}

REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);

// Inputs: Y, dY, and at least one forward input; outputs: one gradient per
// forward input (see SelectGradientOpBase::RunOnDevice above for layout).
OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);

namespace {

// Builds the MaxGradient op for a Max op. The generated op receives
// [O(0), GO(0), I(0), ..., I(n-1)] and produces [GI(0), ..., GI(n-1)],
// matching the input/output layout read by SelectGradientOpBase.
class GetMaxGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    std::vector<std::string> inputs = {O(0), GO(0)};
    std::vector<std::string> grad_inputs;
    for (int i = 0; i < def_.input_size(); ++i) {
      inputs.push_back(I(i));
      grad_inputs.push_back(GI(i));
    }
    return SingleGradientDef("MaxGradient", "", inputs, grad_inputs);
  }
};

// Builds the MinGradient op for a Min op; same layout as GetMaxGradient.
class GetMinGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    std::vector<std::string> inputs = {O(0), GO(0)};
    std::vector<std::string> grad_inputs;
    for (int i = 0; i < def_.input_size(); ++i) {
      inputs.push_back(I(i));
      grad_inputs.push_back(GI(i));
    }
    return SingleGradientDef("MinGradient", "", inputs, grad_inputs);
  }
};

} // namespace

REGISTER_GRADIENT(Max, GetMaxGradient);
REGISTER_GRADIENT(Min, GetMinGradient);

} // namespace caffe2