Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Revert D6893040: Implementing Pow operator (this merges existing pow with a scalar and new pow with a tensor exponent).
Summary:
This reverts commit 30f614beea6f859fee25ce4f85573142885dde45

bypass-lint
An infra SEV is better than not reverting this diff.
If you copy this password, see you in SEV Review!
cause_a_sev_many_files

Differential Revision: D6893040

Original commit changeset: 30f614beea6f

fbshipit-source-id: 5e98a24699088283f864efe31234874bdacbe3c3
This commit is contained in:
parent c746357017
commit 52fa742c51
@@ -113,16 +113,6 @@ CUDA_FUNCTOR(Or, CUDA_OR, BoolTypes, FixedType<bool>);
CUDA_FUNCTOR(Xor, CUDA_XOR, BoolTypes, FixedType<bool>);
#undef CUDA_XOR

// pow, log and other math functions are defined in CUDA math library
// in header file math.h
#define CUDA_POW(x, y) (pow(x, y))
CUDA_FUNCTOR(
    Pow,
    CUDA_POW,
    TensorTypes<float> /*NumericTypes*/,
    SameTypeAsInput);
#undef CUDA_POW

__global__ void NotKernel(const int n, const bool* x, bool* y) {
  CUDA_1D_KERNEL_LOOP(i, n) {
    y[i] = !x[i];
@@ -82,4 +82,65 @@ OPERATOR_SCHEMA(Sign)
    .IdenticalTypeAndShape();
SHOULD_NOT_DO_GRADIENT(Sign);

REGISTER_CPU_OPERATOR(
    Pow,
    UnaryElementwiseWithArgsOp<TensorTypes<float>, CPUContext, PowFunctor>);

OPERATOR_SCHEMA(Pow)
    .NumInputs(1)
    .NumOutputs(1)
    .Arg("exponent", "The exponent of the power function.")
    .AllowInplace({{0, 0}})
    .IdenticalTypeAndShape()
    .SetDoc(R"DOC(
Pow takes input data (Tensor<T>) and an argument exponent, and
produces one output data (Tensor<T>) where the function f(x) = x^exponent
is applied to the data tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor of any shape")
    .Output(0, "Y", "Output tensor (same size as X)");

class GetPowGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    ArgumentHelper arg_helper(def_);
    float exponent = arg_helper.GetSingleArgument<float>("exponent", 0.0);
    Argument scale_arg;
    scale_arg.set_name("scale");
    scale_arg.set_f(exponent);
    Argument pow_arg;
    pow_arg.set_name("exponent");
    if (I(0) != O(0)) {
      pow_arg.set_f(exponent - 1);
    } else {
      LOG(WARNING) << "In-place Pow gradient, possible loss of precision";
      constexpr float kEps = 1e-12f;
      CAFFE_ENFORCE(std::fabs(exponent) > kEps);
      pow_arg.set_f((exponent - 1) / exponent);
    }
    return vector<OperatorDef>{CreateOperatorDef(
                                   "Pow",
                                   "",
                                   std::vector<string>{I(0)},
                                   std::vector<string>{GI(0)},
                                   std::vector<Argument>{pow_arg}),
                               CreateOperatorDef(
                                   "Mul",
                                   "",
                                   std::vector<string>{GI(0), GO(0)},
                                   std::vector<string>{GI(0)}),
                               CreateOperatorDef(
                                   "Scale",
                                   "",
                                   std::vector<string>{GI(0)},
                                   std::vector<string>{GI(0)},
                                   std::vector<Argument>{scale_arg})};
  }
  virtual bool CopyArguments() const override {
    return false;
  }
};

REGISTER_GRADIENT(Pow, GetPowGradient);

} // namespace caffe2
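For orientation, a minimal usage sketch of the single-input form registered and documented in the hunk above, written against the Caffe2 Python API that the test hunks further down also use (the blob names and the exponent value here are illustrative, not taken from this diff):

from caffe2.python import core, workspace
import numpy as np

# Feed a float tensor and run the scalar-exponent Pow restored by this revert.
workspace.FeedBlob("X", np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
op = core.CreateOperator("Pow", ["X"], ["Y"], exponent=2.0)
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("Y"))  # elementwise X ** 2.0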
@@ -52,4 +52,7 @@ REGISTER_CUDA_OPERATOR(
REGISTER_CUDA_OPERATOR(
    Sign,
    UnaryElementwiseOp<TensorTypes<float>, CUDAContext, SignCUDAFunctor>);
REGISTER_CUDA_OPERATOR(
    Pow,
    UnaryElementwiseWithArgsOp<TensorTypes<float>, CUDAContext, PowFunctor>);
}
@@ -25,4 +25,21 @@
#include "caffe2/operators/elementwise_op.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

struct PowFunctor {
  explicit PowFunctor(OperatorBase& op) {
    exponent_ = op.GetSingleArgument<float>("exponent", 0);
  }

  template <typename T, class Context>
  inline void
  operator()(const int n, const T* x, T* y, Context* device_context) {
    math::Powx<float, Context>(n, x, exponent_, y, device_context);
  }

  float exponent_;
};
}

#endif
@@ -1,323 +0,0 @@
/**
 * Copyright (c) 2018-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "caffe2/operators/pow_op.h"
#include "caffe2/utils/math.h"
// definition of NumericTypes and SameTypeAsInput is in below header file
//#include "caffe2/operators/elementwise_op.h"
#include <Eigen/Core>

namespace caffe2 {

#define EIGEN_POW(x, y) (x.pow(y))

struct EigenPowFunctor {
  template <int b_is_scalar, typename T1, typename T2, typename R>
  inline void Run(size_t n, const T1* a, const T2* b, R* out, CPUContext*) {
    if (b_is_scalar) {
      EigenVectorArrayMap<R>(out, n) =
          EIGEN_POW((ConstEigenVectorArrayMap<T1>(a, n)), (b[0]));
    } else {
      EigenVectorArrayMap<R>(out, n) = EIGEN_POW(
          (ConstEigenVectorArrayMap<T1>(a, n)),
          (ConstEigenVectorArrayMap<T2>(b, n)));
    }
  }
  template <typename T1, typename T2, typename R>
  void RunWithBroadcast(
      const T1* a,
      const T2* b,
      R* out,
      size_t pre,
      size_t n,
      CPUContext*) {
    EigenArrayMap<R>(out, n, pre) = EIGEN_POW(
        (ConstEigenArrayMap<T1>(a, n, pre)),
        (ConstEigenVectorArrayMap<T2>(b, n)).rowwise().replicate(pre));
    /*
    //below code only allows elementary ops, such as +, -, * and /,
    //and does not allow operations, such as pow, exp and log
    EIGEN_POW(
        (ConstEigenArrayMap<T>(a, n, pre).colwise()),
        (ConstEigenVectorArrayMap<T>(b, n)));
    */
  }
  template <typename T1, typename T2, typename R>
  void RunWithBroadcast2(
      const T1* a,
      const T2* b,
      R* out,
      size_t pre,
      size_t n,
      size_t post,
      CPUContext*) {
    for (int i = 0; i < pre; ++i) {
      EigenArrayMap<R>(out + i * n * post, post, n) = EIGEN_POW(
          (ConstEigenArrayMap<T1>(a + i * n * post, post, n)),
          (Eigen::Map<const Eigen::Array<T2, 1, Eigen::Dynamic>>(b, n))
              .colwise()
              .replicate(post));
      /*
      //below code only allows elementary ops, such as +, -, * and /,
      //and does not allow for operations, such as pow, exp and log
      EIGEN_POW(
          (ConstEigenArrayMap<T>(a + i * n * post, post, n).rowwise()),
          (Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>>(b, n)));
      */
    }
  }
};

REGISTER_CPU_OPERATOR(
    Pow,
    PowOp<
        TensorTypes<float>, /*NumericTypes,*/
        CPUContext,
        EigenPowFunctor,
        SameTypeAsInput>)

OPERATOR_SCHEMA(Pow)
    .NumInputs(1, 2)
    .NumOutputs(1)
    .Arg("exponent", "The exponent of the power function.")
    .AllowInplace({{0, 0}, {1, 0}})
    .SetDoc(R"DOC(
Pow takes input data (Tensor<T>) and an argument exponent, which can be a
scalar or another tensor. It produces one output data (Tensor<T>), where
the function f(x) = x^exponent is applied to the data tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor of any shape")
    .Input(1, "exponent", "The exponent of the power function.")
    .Output(0, "Y", "Output tensor (same size as X)");

class GetPowGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    ArgumentHelper arg_helper(def_);
    if (arg_helper.HasArgument("exponent")) { // second input is a scalar
      // function f(w,a) = w^a
      // gradient operator with respect to first input tensor
      // df/dw = a * w^(a-1) (all operations are component-wise)
      float exponent = arg_helper.GetSingleArgument<float>("exponent", 0.0);
      Argument scale_arg;
      scale_arg.set_name("scale");
      scale_arg.set_f(exponent);
      Argument pow_arg;
      pow_arg.set_name("exponent");
      if (I(0) != O(0)) {
        pow_arg.set_f(exponent - 1);
      } else {
        LOG(WARNING) << "In-place Pow gradient, possible loss of precision";
        constexpr float kEps = 1e-12f;
        CAFFE_ENFORCE(std::fabs(exponent) > kEps);
        pow_arg.set_f((exponent - 1) / exponent);
      }
      return vector<OperatorDef>{CreateOperatorDef(
                                     "Pow",
                                     "",
                                     std::vector<string>{I(0)},
                                     std::vector<string>{GI(0)},
                                     std::vector<Argument>{pow_arg}),
                                 CreateOperatorDef(
                                     "Mul",
                                     "",
                                     std::vector<string>{GI(0), GO(0)},
                                     std::vector<string>{GI(0)}),
                                 CreateOperatorDef(
                                     "Scale",
                                     "",
                                     std::vector<string>{GI(0)},
                                     std::vector<string>{GI(0)},
                                     std::vector<Argument>{scale_arg})};
      /*
      // Alternative gradient computation
      return vector<OperatorDef>{CreateOperatorDef(
                                     "Div",
                                     "",
                                     std::vector<string>{O(0), I(0)},
                                     std::vector<string>{GI(0)}),
                                 CreateOperatorDef(
                                     "Mul",
                                     "",
                                     std::vector<string>{GI(0), GO(0)},
                                     std::vector<string>{GI(0)}),
                                 CreateOperatorDef(
                                     "Scale",
                                     "",
                                     std::vector<string>{GI(0)},
                                     std::vector<string>{GI(0)},
                                     std::vector<Argument>{scale_arg})};
      */
    } else { // second input is a tensor
      CAFFE_ENFORCE(
          Def().input(0) != Def().output(0) &&
              Def().input(1) != Def().output(0),
          "Gradient computation cannot be carried out if Pow uses in-place "
          "computation: ",
          ProtoDebugString(Def()));
      vector<OperatorDef> grad_ops;
      Argument one_arg;
      one_arg.set_name("value");
      one_arg.set_f(1);
      Argument broadcast, axis, axis_str, order;
      bool bflag = ArgumentHelper::HasArgument(Def(), "broadcast");

      if (bflag) {
        if (ArgumentHelper::HasArgument(Def(), "broadcast")) {
          broadcast = GetArgument(Def(), "broadcast");
        } else {
          broadcast = MakeArgument<int>("broadcast", 0);
        }
        if (ArgumentHelper::HasArgument(Def(), "axis")) {
          axis = GetArgument(Def(), "axis");
        } else {
          axis = MakeArgument<int>("axis", -1);
        }
        if (ArgumentHelper::HasArgument(Def(), "axis_str")) {
          axis_str = GetArgument(Def(), "axis_str");
        } else {
          axis_str = MakeArgument<string>("axis_str", "");
        }
        if (ArgumentHelper::HasArgument(Def(), "order")) {
          order = GetArgument(Def(), "order");
        } else {
          order = MakeArgument<string>("order", "NCHW");
        }
      }

      // function f(w,a) = w^a
      // gradient operator with respect to first input tensor
      // df/dw = a * w^(a-1) (all operations are component-wise)
      grad_ops.push_back(CreateOperatorDef(
          "ConstantFill",
          "",
          std::vector<string>{I(1)},
          std::vector<string>{GI(1)},
          std::vector<Argument>{one_arg}));
      grad_ops.push_back(CreateOperatorDef(
          "Sub",
          "",
          std::vector<string>{I(1), GI(1)},
          std::vector<string>{GI(1)}));
      if (bflag) {
        grad_ops.push_back(CreateOperatorDef(
            "Pow",
            "",
            std::vector<string>{I(0), GI(1)},
            std::vector<string>{GI(0)},
            vector<Argument>{broadcast, axis, axis_str, order}));
      } else {
        grad_ops.push_back(CreateOperatorDef(
            "Pow",
            "",
            std::vector<string>{I(0), GI(1)},
            std::vector<string>{GI(0)}));
      }

      grad_ops.push_back(CreateOperatorDef(
          "Mul",
          "",
          std::vector<string>{GI(0), GO(0)},
          std::vector<string>{GI(0)}));
      if (bflag) {
        grad_ops.push_back(CreateOperatorDef(
            "Mul",
            "",
            std::vector<string>{GI(0), I(1)},
            std::vector<string>{GI(0)},
            vector<Argument>{broadcast, axis, axis_str, order}));
      } else {
        grad_ops.push_back(CreateOperatorDef(
            "Mul",
            "",
            std::vector<string>{GI(0), I(1)},
            std::vector<string>{GI(0)}));
      }
      /*
      // Alternative gradient computation (no broadcast support)
      grad_ops.push_back(CreateOperatorDef(
          "Div",
          "",
          std::vector<string>{O(0), I(0)},
          std::vector<string>{GI(0)}));
      grad_ops.push_back(CreateOperatorDef(
          "Mul",
          "",
          std::vector<string>{GI(0), GO(0)},
          std::vector<string>{GI(0)}));
      grad_ops.push_back(CreateOperatorDef(
          "Mul",
          "",
          std::vector<string>{GI(0), I(1)},
          std::vector<string>{GI(0)}));
      */
      // gradient operator with respect to second input tensor
      // df/da = w^a * ln w (all operations are component-wise)
      /*
      // reset GI(1) to zero
      Argument zero_arg;
      zero_arg.set_name("value");
      zero_arg.set_f(0);
      grad_ops.push_back(CreateOperatorDef(
          "ConstantFill",
          "",
          std::vector<string>{I(1)},
          std::vector<string>{GI(1)},
          std::vector<Argument>{zero_arg}));
      */
      grad_ops.push_back(CreateOperatorDef(
          "Log",
          "",
          std::vector<string>{I(0)},
          std::vector<string>{GI(1) + "_autogen_pre_red"}));
      grad_ops.push_back(CreateOperatorDef(
          "Mul",
          "",
          std::vector<string>{GI(1) + "_autogen_pre_red", O(0)},
          std::vector<string>{GI(1) + "_autogen_pre_red"}));
      if (bflag) {
        grad_ops.push_back(CreateOperatorDef(
            "Mul",
            "",
            std::vector<string>{GI(1) + "_autogen_pre_red", GO(0)},
            std::vector<string>{GI(1) + "_autogen_pre_red"}));
        grad_ops.push_back(CreateOperatorDef(
            "SumReduceLike",
            "",
            vector<string>{GI(1) + "_autogen_pre_red", I(1)},
            vector<string>{GI(1)},
            vector<Argument>{axis, axis_str, order}));
      } else {
        grad_ops.push_back(CreateOperatorDef(
            "Mul",
            "",
            std::vector<string>{GI(1) + "_autogen_pre_red", GO(0)},
            std::vector<string>{GI(1)}));
      }

      return grad_ops;
    }
  }

  // Argument `shape` is no longer needed in backprop.
  bool CopyArguments() const override {
    return false;
  }
};

REGISTER_GRADIENT(Pow, GetPowGradient);

} // namespace caffe2
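The gradient maker in the deleted pow_op.cc above builds its operator chains from the derivative formulas noted in its comments: for f(w, a) = w^a, df/dw = a * w^(a-1) and df/da = w^a * ln(w). A small NumPy sketch that checks those formulas against finite differences (the sample values and tolerances are my own, purely illustrative):

import numpy as np

w = np.array([1.5, 2.0, 3.0], dtype=np.float64)
a = np.array([2.0, 0.5, 3.0], dtype=np.float64)
eps = 1e-6

dy_dw = a * w ** (a - 1)        # df/dw = a * w^(a-1)
dy_da = (w ** a) * np.log(w)    # df/da = w^a * ln(w)

# Central finite differences, elementwise.
num_dw = ((w + eps) ** a - (w - eps) ** a) / (2 * eps)
num_da = (w ** (a + eps) - w ** (a - eps)) / (2 * eps)
assert np.allclose(dy_dw, num_dw, atol=1e-4)
assert np.allclose(dy_da, num_da, atol=1e-4)

# In the scalar-exponent branch, when Y = X^a is computed in place the input is
# gone, so the maker recovers X^(a-1) as Y^((a-1)/a); that is where the
# (exponent - 1) / exponent argument in both GetPowGradient versions comes from.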
@@ -1,149 +0,0 @@
/**
 * Copyright (c) 2018-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CAFFE2_OPERATORS_POW_OP_H_
#define CAFFE2_OPERATORS_POW_OP_H_

#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
// definition of NumericTypes and SameTypeAsInput is in below header file
#include "caffe2/operators/elementwise_op.h"

namespace caffe2 {

template <
    typename InputTypes,
    class Context,
    class Functor,
    class TypeMap = SameTypeAsInput>
class PowOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  PowOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        OP_SINGLE_ARG(bool, "broadcast", enable_broadcast_, 0),
        OP_SINGLE_ARG(int, "axis", axis_, -1),
        OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
        OP_SINGLE_ARG(string, "order", order_, "NCHW"),
        functor_() {
    if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp
      exponent_ = this->template GetSingleArgument<float>(
          "exponent", 0); // based on pow_ops.h
    } else if (InputSize() == 2) { // BinaryElementwiseOp
      // Figure out the correct axis to use.
      if (enable_broadcast_) {
        if (axis_ != -1) {
          // Get axis from an explicit axis argument.
          CAFFE_ENFORCE_EQ(
              axis_str_.size(),
              0,
              "Args axis and axis_str cannot be used simultaneously.");
        } else if (axis_str_.size()) {
          // Get the axis index semantically.
          CAFFE_ENFORCE_EQ(
              axis_str_.size(), 1, "Unsupported axis string", axis_str_);
          size_t semantic_axis_ = order_.find(axis_str_);
          CAFFE_ENFORCE_NE(
              semantic_axis_,
              string::npos,
              "Unrecognizable axis string ",
              axis_str_,
              " from order string ",
              order_);
          axis_ = semantic_axis_;
        }
      } else {
        CAFFE_ENFORCE(
            axis_ == -1 && axis_str_.size() == 0,
            "Do not specify axis or axis_str if broadcast is not enabled.");
      }
    } else {
      CAFFE_THROW(
          "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
    }
  }

  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp
      const auto& A = Input(0);
      auto* C = Output(0);
      C->ResizeLike(A);
      const T* Adata = A.template data<T>();
      auto* Cdata =
          C->template mutable_data<typename TypeMap::template type<T>>();
      functor_.template Run<true, T, float, T>(
          A.size(), Adata, &exponent_, Cdata, &context_);
    } else if (InputSize() == 2) { // BinaryElementwiseOp
      const auto& A = Input(0);
      const auto& B = Input(1);
      auto* C = Output(0);
      CAFFE_ENFORCE(
          &B != C || !enable_broadcast_,
          "In-place is allowed only with the first tensor when broadcasting");
      C->ResizeLike(A);
      const T* Adata = A.template data<T>();
      const T* Bdata = B.template data<T>();
      auto* Cdata =
          C->template mutable_data<typename TypeMap::template type<T>>();
      if (!enable_broadcast_) {
        CAFFE_ENFORCE_EQ(
            A.dims(),
            B.dims(),
            "Dimension mismatch - did you forget to set broadcast=1?");
        functor_.template Run<false, T, T, T>(
            A.size(), Adata, Bdata, Cdata, &context_);
      } else if (B.size() == 1) {
        functor_.template Run<true, T, T, T>(
            A.size(), Adata, Bdata, Cdata, &context_);
      } else {
        size_t pre, n, post;
        std::tie(pre, n, post) = calculate_broadcast_sizes(A, B, axis_);
        if (post == 1) {
          functor_.template RunWithBroadcast<T, T, T>(
              Adata, Bdata, Cdata, pre, n, &context_);
        } else {
          functor_.template RunWithBroadcast2<T, T, T>(
              Adata, Bdata, Cdata, pre, n, post, &context_);
        }
      }
    } else {
      CAFFE_THROW(
          "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
    }
    return true;
  }

 private:
  bool enable_broadcast_;
  int axis_;
  string axis_str_;
  string order_;
  float exponent_;
  Functor functor_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_POW_OP_H_
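In the deleted pow_op.h above, the two-tensor path reduces broadcasting to a (pre, n, post) decomposition via calculate_broadcast_sizes and then calls RunWithBroadcast or RunWithBroadcast2. A rough NumPy model of that decomposition, under the assumption that pre covers the dimensions of A before the alignment axis, n is the size of B, and post covers the remaining trailing dimensions (the shapes and axis value are illustrative, not taken from the diff):

import numpy as np

A = np.random.rand(2, 3, 4, 5).astype(np.float32)
B = np.random.rand(3, 4).astype(np.float32) + 2.0  # aligned at axis=1
axis = 1

pre = int(np.prod(A.shape[:axis]))             # dims of A before the axis
n = B.size                                     # size of the exponent tensor
post = int(np.prod(A.shape[axis + B.ndim:]))   # trailing dims of A after B

# RunWithBroadcast2-style computation: view A as (pre, n, post) and broadcast B
# over the first and last dimensions.
out = A.reshape(pre, n, post) ** B.reshape(1, n, 1)
expected = np.power(A, B[np.newaxis, :, :, np.newaxis])
assert np.allclose(out.reshape(A.shape), expected)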
@@ -1107,27 +1107,6 @@ class TestOperators(hu.HypothesisTestCase):
            reference=log_ref)
        self.assertGradientChecks(gc, op, [input_tensor], 0, [0])

    @given(input_tensors=hu.tensors(n=2, elements=st.floats(min_value=2.0, max_value=3.0, allow_nan=False, allow_infinity=False)),
           **hu.gcs_cpu_only)
    def test_powt(self, input_tensors, gc, dc):
        X1, X2 = input_tensors

        op = core.CreateOperator(
            "Pow",
            ["X1", "X2"],
            ["output"]
        )

        def powt_ref(X1, X2):
            return (np.power(X1, X2),)

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[X1, X2],
            reference=powt_ref)
        self.assertGradientChecks(gc, op, [X1, X2], 0, [0])

    def test_blobs_dequeue_timeout(self):
        op = core.CreateOperator(
            "CreateBlobsQueue",
@@ -183,58 +183,6 @@ class TestElementwiseBroadcast(hu.HypothesisTestCase):
        self.assertDeviceChecks(dc, op, [X, Y], [0])
        self.assertGradientChecks(gc, op, [X, Y], 1, [0])

    @given(**hu.gcs)
    def test_broadcast_powt(self, gc, dc):
        # Set broadcast and no axis, i.e. broadcasting last dimensions.
        X = np.random.rand(2, 3, 4, 5).astype(np.float32)
        Y = np.random.rand(4, 5).astype(np.float32) + 2.0

        op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1)
        workspace.FeedBlob("X", X)
        workspace.FeedBlob("Y", Y)
        workspace.RunOperatorOnce(op)
        out = workspace.FetchBlob("out")
        np.testing.assert_array_almost_equal(out, np.power(X, Y))
        self.assertDeviceChecks(dc, op, [X, Y], [0])
        self.assertGradientChecks(gc, op, [X, Y], 1, [0])

        # broadcasting intermediate dimensions
        X = np.random.rand(2, 3, 4, 5).astype(np.float32)
        Y = np.random.rand(3, 4).astype(np.float32) + 2.0
        op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1, axis=1)
        workspace.FeedBlob("X", X)
        workspace.FeedBlob("Y", Y)
        workspace.RunOperatorOnce(op)
        out = workspace.FetchBlob("out")
        np.testing.assert_array_almost_equal(out, np.power(X, Y[:, :, np.newaxis]))
        self.assertDeviceChecks(dc, op, [X, Y], [0])
        self.assertGradientChecks(gc, op, [X, Y], 1, [0])

        # broadcasting the first dimension
        X = np.random.rand(2, 3, 4, 5).astype(np.float32)
        Y = np.random.rand(2).astype(np.float32) + 2.0
        op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1, axis=0)
        workspace.FeedBlob("X", X)
        workspace.FeedBlob("Y", Y)
        workspace.RunOperatorOnce(op)
        out = workspace.FetchBlob("out")
        np.testing.assert_array_almost_equal(
            out, np.power(X, Y[:, np.newaxis, np.newaxis, np.newaxis]))
        self.assertDeviceChecks(dc, op, [X, Y], [0])
        self.assertGradientChecks(gc, op, [X, Y], 1, [0])

        # broadcasting with single elem dimensions at both ends
        X = np.random.rand(2, 3, 4, 5).astype(np.float32)
        Y = np.random.rand(1, 4, 1).astype(np.float32) + 2.0
        op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1, axis=1)
        workspace.FeedBlob("X", X)
        workspace.FeedBlob("Y", Y)
        workspace.RunOperatorOnce(op)
        out = workspace.FetchBlob("out")
        np.testing.assert_array_almost_equal(out, np.power(X, Y))
        self.assertDeviceChecks(dc, op, [X, Y], [0])
        self.assertGradientChecks(gc, op, [X, Y], 1, [0])

    @given(**hu.gcs)
    def test_broadcast_scalar(self, gc, dc):
        # broadcasting constant
@@ -75,31 +75,6 @@ class TestElementwiseOps(hu.HypothesisTestCase):
        self.assertGradientChecks(
            gc, op, [X], 0, [0], stepsize=1e-4, threshold=1e-2)

    @given(n=st.integers(2, 10), m=st.integers(4, 6),
           d=st.integers(2, 3), **hu.gcs)
    def test_powt(self, n, m, d, gc, dc):
        X = np.random.rand(n, m, d).astype(np.float32)
        Y = np.random.rand(n, m, d).astype(np.float32) + 2.0

        def powt_op(X, Y):
            return [np.power(X, Y)]

        op = core.CreateOperator(
            "Pow",
            ["X", "Y"],
            ["Z"]
        )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[X, Y],
            reference=powt_op,
        )

        self.assertGradientChecks(
            gc, op, [X, Y], 0, [0], stepsize=1e-4, threshold=1e-2)

    @given(n=st.integers(5, 6), m=st.integers(4, 6), **hu.gcs)
    def test_sqr(self, n, m, gc, dc):
        X = np.random.rand(n, m).astype(np.float32)