mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add CPU version of hard sigmoid operator to caffe2 (#10837)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10837 Add CPU version of hard sigmoid operator to caffe2. The definition of this operator can be found here: https://github.com/onnx/onnx/blob/master/docs/Operators.md#HardSigmoid. Reviewed By: BIT-silence Differential Revision: D9489536 fbshipit-source-id: 67b3171ed96d5ebcc8d500d93e7827a4a9705a81
This commit is contained in:
parent
efd2aeac9e
commit
92ff070b83
154
caffe2/operators/hard_sigmoid_op.cc
Normal file
154
caffe2/operators/hard_sigmoid_op.cc
Normal file
|
|
@ -0,0 +1,154 @@
|
||||||
|
#include "caffe2/operators/hard_sigmoid_op.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "caffe2/utils/eigen_utils.h"
|
||||||
|
|
||||||
|
namespace caffe2 {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
template <typename T>
|
||||||
|
bool HardSigmoidFunctor<CPUContext>::
|
||||||
|
operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const {
|
||||||
|
EigenVectorArrayMap<T>(Y, N) =
|
||||||
|
(ConstEigenVectorArrayMap<T>(X, N) * T(alpha) + T(beta))
|
||||||
|
.cwiseMin(T(1))
|
||||||
|
.cwiseMax(T(0));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
template <typename T>
|
||||||
|
bool HardSigmoidGradientFunctor<CPUContext>::Forward(
|
||||||
|
const std::vector<int>& Y_dims,
|
||||||
|
const std::vector<int>& /* dY_dims */,
|
||||||
|
const T* Y,
|
||||||
|
const T* dY,
|
||||||
|
T* dX,
|
||||||
|
CPUContext* /* context */) const {
|
||||||
|
const int size = std::accumulate(
|
||||||
|
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
|
||||||
|
ConstEigenVectorArrayMap<T> Y_arr(Y, size);
|
||||||
|
EigenVectorArrayMap<T>(dX, size) =
|
||||||
|
(Y_arr > T(0) && Y_arr < T(1))
|
||||||
|
.select(ConstEigenVectorArrayMap<T>(dY, size) * alpha, T(0));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
OpSchema::Cost CostInferenceForHardSigmoid(
|
||||||
|
const OperatorDef& def,
|
||||||
|
const vector<TensorShape>& in) {
|
||||||
|
struct OpSchema::Cost cost = PointwiseCostInference<4>(def, in);
|
||||||
|
cost.params_bytes = 0;
|
||||||
|
return cost;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
REGISTER_CPU_OPERATOR(
|
||||||
|
HardSigmoid,
|
||||||
|
UnaryElementwiseWithArgsOp<
|
||||||
|
TensorTypes<float>,
|
||||||
|
CPUContext,
|
||||||
|
HardSigmoidFunctor<CPUContext>>);
|
||||||
|
REGISTER_CPU_OPERATOR(
|
||||||
|
HardSigmoidGradient,
|
||||||
|
BinaryElementwiseWithArgsOp<
|
||||||
|
TensorTypes<float>,
|
||||||
|
CPUContext,
|
||||||
|
HardSigmoidGradientFunctor<CPUContext>>);
|
||||||
|
|
||||||
|
// Input: X, output: Y
|
||||||
|
OPERATOR_SCHEMA(HardSigmoid)
|
||||||
|
.NumInputs(1)
|
||||||
|
.NumOutputs(1)
|
||||||
|
.AllowInplace({{0, 0}})
|
||||||
|
.CostInferenceFunction(CostInferenceForHardSigmoid)
|
||||||
|
.IdenticalTypeAndShape()
|
||||||
|
.SetDoc(R"DOC(
|
||||||
|
Applies hard sigmoid operation to the input data element-wise.
|
||||||
|
The HardSigmoid operation takes one input $X$, produces one output $Y$, and is defined as:
|
||||||
|
|
||||||
|
$$Y = max(0,min(1,x * alpha + beta))$$
|
||||||
|
|
||||||
|
Github Links:
|
||||||
|
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/hard_sigmoid_op.h
|
||||||
|
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/hard_sigmoid_op.cc
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary> <b>Example</b> </summary>
|
||||||
|
|
||||||
|
**Code**
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
workspace.ResetWorkspace()
|
||||||
|
|
||||||
|
op = core.CreateOperator(
|
||||||
|
"HardSigmoid",
|
||||||
|
["X"],
|
||||||
|
["Y"],
|
||||||
|
alpha = 0.2,
|
||||||
|
beta = 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
workspace.FeedBlob("X", np.random.randn(5).astype(np.float32))
|
||||||
|
print("input:", workspace.FetchBlob("X"))
|
||||||
|
workspace.RunOperatorOnce(op)
|
||||||
|
print("sigmoid:", workspace.FetchBlob("Y"))
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
**Result**
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
input: [ 1.5744036 0.31632107 1.7842269 1.4450722 -2.1726978 ]
|
||||||
|
hard_sigmoid: [ 0.81488073, 0.56326419, 0.85684538, 0.78901446, 0.06546044]
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
)DOC")
|
||||||
|
.Arg("alpha", "float: the slope of the function. Defaults to 0.2")
|
||||||
|
.Arg("beta", "float: the bias value of the function. Defaults to 0.5")
|
||||||
|
.Input(0, "X", "1D input tensor")
|
||||||
|
.Output(0, "Y", "1D output tensor with same shape as input")
|
||||||
|
.InheritOnnxSchema("HardSigmoid");
|
||||||
|
|
||||||
|
// Input: Y, dY, output: dX
|
||||||
|
OPERATOR_SCHEMA(HardSigmoidGradient)
|
||||||
|
.NumInputs(2)
|
||||||
|
.NumOutputs(1)
|
||||||
|
.AllowInplace({{1, 0}})
|
||||||
|
.SetDoc(R"DOC(
|
||||||
|
HardSigmoidGradient takes both Y and dY as well as an argument alpha and uses
|
||||||
|
this to update dX according to the chain rule and derivatives of the hard
|
||||||
|
sigmoid function.
|
||||||
|
)DOC");
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class GetHardSigmoidGradient : public GradientMakerBase {
|
||||||
|
using GradientMakerBase::GradientMakerBase;
|
||||||
|
std::vector<OperatorDef> GetGradientDefs() override {
|
||||||
|
return SingleGradientDef(
|
||||||
|
def_.type() + "Gradient",
|
||||||
|
"",
|
||||||
|
std::vector<std::string>{O(0), GO(0)},
|
||||||
|
std::vector<std::string>{GI(0)});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
REGISTER_GRADIENT(HardSigmoid, GetHardSigmoidGradient);
|
||||||
|
|
||||||
|
} // namespace caffe2
|
||||||
41
caffe2/operators/hard_sigmoid_op.h
Normal file
41
caffe2/operators/hard_sigmoid_op.h
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
#ifndef CAFFE2_OPERATORS_HARD_SIGMOID_H_
|
||||||
|
#define CAFFE2_OPERATORS_HARD_SIGMOID_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "caffe2/operators/elementwise_ops.h"
|
||||||
|
|
||||||
|
namespace caffe2 {
|
||||||
|
|
||||||
|
template <class Context>
|
||||||
|
struct HardSigmoidFunctor {
|
||||||
|
explicit HardSigmoidFunctor(OperatorBase& op)
|
||||||
|
: alpha(op.GetSingleArgument<float>("alpha", 0.2f)),
|
||||||
|
beta(op.GetSingleArgument<float>("beta", 0.5f)) {}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(const int N, const T* X, T* Y, Context* context) const;
|
||||||
|
|
||||||
|
const float alpha, beta;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class Context>
|
||||||
|
struct HardSigmoidGradientFunctor {
|
||||||
|
explicit HardSigmoidGradientFunctor(OperatorBase& op)
|
||||||
|
: alpha(op.GetSingleArgument<float>("alpha", 0.2f)) {}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool Forward(
|
||||||
|
const std::vector<int>& Y_dims,
|
||||||
|
const std::vector<int>& dY_dims,
|
||||||
|
const T* Y,
|
||||||
|
const T* dY,
|
||||||
|
T* dX,
|
||||||
|
Context* context) const;
|
||||||
|
|
||||||
|
const float alpha;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace caffe2
|
||||||
|
|
||||||
|
#endif // CAFFE2CAFFE2_OPERATORS_HARD_SIGMOID_H_
|
||||||
|
|
@ -4,7 +4,7 @@ from __future__ import print_function
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from caffe2.python import core, workspace
|
from caffe2.python import core, workspace
|
||||||
from hypothesis import given
|
from hypothesis import given, assume
|
||||||
import caffe2.python.hypothesis_test_util as hu
|
import caffe2.python.hypothesis_test_util as hu
|
||||||
import hypothesis.strategies as st
|
import hypothesis.strategies as st
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -333,6 +333,46 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||||
self.assertDeviceChecks(dc, op, [X], [0])
|
self.assertDeviceChecks(dc, op, [X], [0])
|
||||||
self.assertGradientChecks(gc, op, [X], 0, [0])
|
self.assertGradientChecks(gc, op, [X], 0, [0])
|
||||||
|
|
||||||
|
@given(X=hu.tensor(dtype=np.float32),
|
||||||
|
inplace=st.booleans(),
|
||||||
|
alpha=st.floats(min_value=-100.0, max_value=100.0),
|
||||||
|
beta=st.floats(min_value=-100.0, max_value=100.0),
|
||||||
|
engine=st.sampled_from([""]),
|
||||||
|
**hu.gcs_cpu_only)
|
||||||
|
def test_hard_sigmoid(self, X, inplace, alpha, beta, engine, gc, dc):
|
||||||
|
# Prevent alpha and beta from mutually being 0 to avoid a division
|
||||||
|
# error when adjusting our inputs
|
||||||
|
assume(alpha != 0.0 or beta != 0.0)
|
||||||
|
op = core.CreateOperator(
|
||||||
|
"HardSigmoid",
|
||||||
|
["X"],
|
||||||
|
["X"] if inplace else ["Y"],
|
||||||
|
alpha=alpha,
|
||||||
|
beta=beta,
|
||||||
|
engine=engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
def hard_sigmoid_ref(X):
|
||||||
|
return [np.minimum(1.0, np.maximum(0.0, X * alpha + beta))]
|
||||||
|
|
||||||
|
# Adjust inputs to avoid differentitating at inflection points
|
||||||
|
if abs(alpha) > 0.001:
|
||||||
|
Y = X * alpha + beta
|
||||||
|
Y += 0.04 * np.sign(Y)
|
||||||
|
Y[Y == 0.0] += 0.1
|
||||||
|
Y[Y == 1.0] -= 0.1
|
||||||
|
X = (Y - beta) / alpha
|
||||||
|
|
||||||
|
self.assertReferenceChecks(
|
||||||
|
device_option=gc,
|
||||||
|
op=op,
|
||||||
|
inputs=[X],
|
||||||
|
reference=hard_sigmoid_ref,
|
||||||
|
)
|
||||||
|
self.assertDeviceChecks(dc, op, [X], [0])
|
||||||
|
self.assertGradientChecks(
|
||||||
|
gc, op, [X], 0, [0], stepsize=1e-4, threshold=1e-2)
|
||||||
|
|
||||||
@given(n=st.integers(0, 6), m=st.integers(4, 6), **hu.gcs)
|
@given(n=st.integers(0, 6), m=st.integers(4, 6), **hu.gcs)
|
||||||
def test_eq(self, n, m, gc, dc):
|
def test_eq(self, n, m, gc, dc):
|
||||||
# Set broadcast and no axis, i.e. broadcasting last dimensions.
|
# Set broadcast and no axis, i.e. broadcasting last dimensions.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user