refactor caffe2 operator constructors - 8/9 (#17089)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17089

clangr codemod

Reviewed By: ezyang

Differential Revision: D14078539

fbshipit-source-id: 9ca196af4af7f26fc82e6cf82b35d478d0597752
This commit is contained in:
Sebastian Messmer 2019-02-28 14:12:37 -08:00 committed by Facebook Github Bot
parent 28b5df1c8f
commit 7413f0926a
25 changed files with 137 additions and 93 deletions

View File

@ -202,8 +202,9 @@ template <class Context>
class SliceOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SliceOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SliceOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
starts_(this->template GetRepeatedArgument<int64_t>("starts")),
ends_(this->template GetRepeatedArgument<int64_t>("ends")),
statically_inited_(false) {}
@ -263,8 +264,9 @@ template <class Context>
class SliceGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SliceGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SliceGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
starts_(this->template GetRepeatedArgument<int64_t>("starts")),
ends_(this->template GetRepeatedArgument<int64_t>("ends")),
statically_inited_(false) {}

View File

@ -11,9 +11,10 @@ namespace caffe2 {
template <typename T, class Context>
class SoftmaxOp final : public Operator<Context> {
public:
SoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
axis_(this->template GetSingleArgument<int>("axis", 1)) {}
template <class... Args>
explicit SoftmaxOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
axis_(this->template GetSingleArgument<int>("axis", 1)) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
@ -27,8 +28,9 @@ class SoftmaxOp final : public Operator<Context> {
template <typename T, class Context>
class SoftmaxGradientOp final : public Operator<Context> {
public:
SoftmaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SoftmaxGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
axis_(this->template GetSingleArgument<int>("axis", 1)) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;

View File

@ -15,8 +15,9 @@ constexpr int TOP_GRADIENT_DESC_ID = 2;
class CuDNNSoftmaxOp final : public Operator<CUDAContext> {
public:
explicit CuDNNSoftmaxOp(const OperatorDef& def, Workspace* ws)
: Operator<CUDAContext>(def, ws),
template <class... Args>
explicit CuDNNSoftmaxOp(Args&&... args)
: Operator<CUDAContext>(std::forward<Args>(args)...),
cudnn_wrapper_(&context_),
axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
@ -77,8 +78,9 @@ class CuDNNSoftmaxOp final : public Operator<CUDAContext> {
class CuDNNSoftmaxGradientOp final : public Operator<CUDAContext> {
public:
explicit CuDNNSoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<CUDAContext>(def, ws),
template <class... Args>
explicit CuDNNSoftmaxGradientOp(Args&&... args)
: Operator<CUDAContext>(std::forward<Args>(args)...),
cudnn_wrapper_(&context_),
axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));

View File

@ -11,10 +11,12 @@ namespace caffe2 {
template <typename T, class Context>
class SoftmaxWithLossOp final : public Operator<Context> {
public:
SoftmaxWithLossOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SoftmaxWithLossOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
scale_(this->template GetSingleArgument<float>("scale", 1.)),
label_prob_mode_(this->template GetSingleArgument<int>("label_prob", 0)),
label_prob_mode_(
this->template GetSingleArgument<int>("label_prob", 0)),
order_(StringToStorageOrder(
this->template GetSingleArgument<string>("order", "NCHW"))),
axis_(this->template GetSingleArgument<int>("axis", 1)) {
@ -44,10 +46,12 @@ class SoftmaxWithLossOp final : public Operator<Context> {
template <typename T, class Context>
class SoftmaxWithLossGradientOp final : public Operator<Context> {
public:
SoftmaxWithLossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
template <class... Args>
explicit SoftmaxWithLossGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
scale_(this->template GetSingleArgument<float>("scale", 1.)),
label_prob_mode_(this->template GetSingleArgument<int>("label_prob", 0)),
label_prob_mode_(
this->template GetSingleArgument<int>("label_prob", 0)),
order_(StringToStorageOrder(
this->template GetSingleArgument<string>("order", "NCHW"))),
only_loss_(this->template GetSingleArgument<bool>("only_loss", false)),

View File

@ -111,8 +111,9 @@ template <typename Context>
class SpaceBatchOpBase : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SpaceBatchOpBase(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SpaceBatchOpBase(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
pad_(this->template GetSingleArgument<int>("pad", 0)),
pad_t_(this->template GetSingleArgument<int>("pad_t", pad_)),
pad_l_(this->template GetSingleArgument<int>("pad", pad_)),

View File

@ -9,8 +9,9 @@ template <typename T, class Context>
class CAFFE2_API SparseNormalizeOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SparseNormalizeOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SparseNormalizeOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
use_max_norm_(
this->template GetSingleArgument<bool>("use_max_norm", true)),
norm_(this->template GetSingleArgument<float>("norm", 1.0)) {

View File

@ -15,8 +15,9 @@ template <class Context>
class SparseToDenseMaskBase : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SparseToDenseMaskBase(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {
template <class... Args>
explicit SparseToDenseMaskBase(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {
std::vector<int64_t> mask =
this->template GetRepeatedArgument<int64_t>("mask");
featuresCount_ = mask.size();
@ -62,8 +63,9 @@ template <class Context>
class SparseToDenseMaskOp : public SparseToDenseMaskBase<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SparseToDenseMaskOp(const OperatorDef& operator_def, Workspace* ws)
: SparseToDenseMaskBase<Context>(operator_def, ws) {
template <class... Args>
explicit SparseToDenseMaskOp(Args&&... args)
: SparseToDenseMaskBase<Context>(std::forward<Args>(args)...) {
returnPresenceMask_ = this->template GetSingleArgument<bool>(
"return_presence_mask", false);
maxSkippedSparseIndices_ =
@ -192,8 +194,9 @@ template <class Context>
class SparseToDenseMaskGradientOp : public SparseToDenseMaskBase<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SparseToDenseMaskGradientOp(const OperatorDef& operator_def, Workspace* ws)
: SparseToDenseMaskBase<Context>(operator_def, ws) {}
template <class... Args>
explicit SparseToDenseMaskGradientOp(Args&&... args)
: SparseToDenseMaskBase<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(

View File

@ -13,8 +13,9 @@ class SparseToDenseOp final : public Operator<Context> {
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_DISPATCH_HELPER;
SparseToDenseOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SparseToDenseOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
output_first_dim_(
this->template GetSingleArgument<int>("output_first_dim", 0)) {}

View File

@ -19,8 +19,9 @@ class SpatialBNOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SpatialBNOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SpatialBNOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(bool, OpSchema::Arg_IsTest, is_test_, false),
OP_SINGLE_ARG(double, "epsilon", epsilon_, 1e-5),
OP_SINGLE_ARG(float, "momentum", momentum_, 0.9f),
@ -281,8 +282,9 @@ class SpatialBNGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SpatialBNGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SpatialBNGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(double, "epsilon", epsilon_, 1e-5),
order_(StringToStorageOrder(
this->template GetSingleArgument<string>("order", "NCHW"))),

View File

@ -11,8 +11,9 @@ namespace caffe2 {
template <typename T, class Context>
class SpatialSoftmaxWithLossOp final : public Operator<Context> {
public:
SpatialSoftmaxWithLossOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SpatialSoftmaxWithLossOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
scale_(this->template GetSingleArgument<float>("scale", 1.)),
order_(StringToStorageOrder(
this->template GetSingleArgument<string>("order", "NCHW"))) {
@ -39,8 +40,9 @@ class SpatialSoftmaxWithLossOp final : public Operator<Context> {
template <typename T, class Context>
class SpatialSoftmaxWithLossGradientOp final : public Operator<Context> {
public:
SpatialSoftmaxWithLossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
template <class... Args>
explicit SpatialSoftmaxWithLossGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
scale_(this->template GetSingleArgument<float>("scale", 1.)),
order_(StringToStorageOrder(
this->template GetSingleArgument<string>("order", "NCHW"))),

View File

@ -13,8 +13,9 @@ class SquareRootDivideOp final : public Operator<Context> {
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_DISPATCH_HELPER;
SquareRootDivideOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit SquareRootDivideOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<float>>::call(this, Input(DATA));

View File

@ -8,8 +8,9 @@ namespace caffe2 {
class StatRegistryCreateOp : public Operator<CPUContext> {
public:
StatRegistryCreateOp(const OperatorDef& operator_def, Workspace* ws)
: Operator(operator_def, ws) {}
template <class... Args>
explicit StatRegistryCreateOp(Args&&... args)
: Operator(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
*OperatorBase::Output<std::unique_ptr<StatRegistry>>(0) =
@ -20,8 +21,9 @@ class StatRegistryCreateOp : public Operator<CPUContext> {
class StatRegistryExportOp : public Operator<CPUContext> {
public:
StatRegistryExportOp(const OperatorDef& operator_def, Workspace* ws)
: Operator(operator_def, ws),
template <class... Args>
explicit StatRegistryExportOp(Args&&... args)
: Operator(std::forward<Args>(args)...),
reset_(GetSingleArgument<bool>("reset", true)) {}
bool RunOnDevice() override {
@ -55,8 +57,9 @@ class StatRegistryExportOp : public Operator<CPUContext> {
class StatRegistryUpdateOp : public Operator<CPUContext> {
public:
StatRegistryUpdateOp(const OperatorDef& operator_def, Workspace* ws)
: Operator(operator_def, ws) {}
template <class... Args>
explicit StatRegistryUpdateOp(Args&&... args)
: Operator(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
const auto& keys = Input(0);
@ -118,7 +121,7 @@ class TimerInstance {
};
struct TimerBeginOp : public Operator<CPUContext> {
TimerBeginOp(const OperatorDef& operator_def, Workspace* ws)
explicit TimerBeginOp(const OperatorDef& operator_def, Workspace* ws)
: Operator(operator_def, ws),
given_name_(GetSingleArgument<std::string>(
"counter_name",
@ -137,8 +140,8 @@ struct TimerBeginOp : public Operator<CPUContext> {
};
struct TimerEndOp : public Operator<CPUContext> {
TimerEndOp(const OperatorDef& operator_def, Workspace* ws)
: Operator(operator_def, ws) {}
template <class... Args>
explicit TimerEndOp(Args&&... args) : Operator(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
OperatorBase::Input<TimerInstance*>(0)->end();
@ -147,8 +150,9 @@ struct TimerEndOp : public Operator<CPUContext> {
};
struct TimerGetAndEndOp : public Operator<CPUContext> {
TimerGetAndEndOp(const OperatorDef& operator_def, Workspace* ws)
: Operator(operator_def, ws) {}
template <class... Args>
explicit TimerGetAndEndOp(Args&&... args)
: Operator(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
int64_t nanos = OperatorBase::Input<TimerInstance*>(0)->get_ns();
@ -161,8 +165,8 @@ struct TimerGetAndEndOp : public Operator<CPUContext> {
};
struct TimerGetOp : public Operator<CPUContext> {
TimerGetOp(const OperatorDef& operator_def, Workspace* ws)
: Operator(operator_def, ws) {}
template <class... Args>
explicit TimerGetOp(Args&&... args) : Operator(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
int64_t nanos = OperatorBase::Input<TimerInstance*>(0)->get_ns();

View File

@ -8,7 +8,7 @@ namespace caffe2 {
template <typename T>
struct TemplatePutOp : public Operator<CPUContext> {
TemplatePutOp(const OperatorDef& operator_def, Workspace* ws)
explicit TemplatePutOp(const OperatorDef& operator_def, Workspace* ws)
: Operator(operator_def, ws),
given_name_(GetSingleArgument<std::string>(
"stat_name",

View File

@ -41,8 +41,9 @@ class StringJoinOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
StringJoinOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit StringJoinOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
delimiter_(
this->template GetSingleArgument<std::string>("delimiter", ",")),
axis_(this->template GetSingleArgument<int>("axis", 0)) {

View File

@ -32,8 +32,9 @@ class StumpFuncOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
StumpFuncOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit StumpFuncOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
threshold_(this->template GetSingleArgument<TIN>("threshold", 0)),
low_value_(this->template GetSingleArgument<TOUT>("low_value", 0)),
high_value_(this->template GetSingleArgument<TOUT>("high_value", 0)) {}
@ -53,8 +54,9 @@ class StumpFuncIndexOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
StumpFuncIndexOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit StumpFuncIndexOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
threshold_(this->template GetSingleArgument<TIN>("threshold", 0)) {}
bool RunOnDevice() override;

View File

@ -69,9 +69,7 @@ class PackedInt8BGRANHWCToNCHWCStylizerPreprocessOp
static constexpr int kNeonNoiseReadSize = kOutputChannels * 16;
USE_OPERATOR_FUNCTIONS(CPUContext);
PackedInt8BGRANHWCToNCHWCStylizerPreprocessOp(
const OperatorDef& operator_def,
Workspace* ws)
explicit PackedInt8BGRANHWCToNCHWCStylizerPreprocessOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws), ws_(ws) {}
bool RunOnDevice() override {

View File

@ -14,7 +14,7 @@ constexpr char kSummaryzeOpExtension[] = ".summary";
template <typename T, class Context>
class SummarizeOp final : public Operator<Context> {
public:
SummarizeOp(const OperatorDef& def, Workspace* ws)
explicit SummarizeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
to_file_(this->template GetSingleArgument<int>("to_file", 0)) {
if (to_file_) {

View File

@ -38,8 +38,9 @@ struct TextFileReaderInstance {
class CreateTextFileReaderOp : public Operator<CPUContext> {
public:
CreateTextFileReaderOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws),
template <class... Args>
explicit CreateTextFileReaderOp(Args&&... args)
: Operator<CPUContext>(std::forward<Args>(args)...),
filename_(GetSingleArgument<string>("filename", "")),
numPasses_(GetSingleArgument<int>("num_passes", 1)),
fieldTypes_(GetRepeatedArgument<int>("field_types")) {
@ -86,8 +87,9 @@ inline void convert(
class TextFileReaderReadOp : public Operator<CPUContext> {
public:
TextFileReaderReadOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws),
template <class... Args>
explicit TextFileReaderReadOp(Args&&... args)
: Operator<CPUContext>(std::forward<Args>(args)...),
batchSize_(GetSingleArgument<int>("batch_size", 1)) {}
bool RunOnDevice() override {

View File

@ -12,8 +12,9 @@ template <typename T, class Context>
class ThresholdedReluOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ThresholdedReluOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {
template <class... Args>
explicit ThresholdedReluOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {
alpha_ = this->template GetSingleArgument<T>("alpha", 1.0);
}
@ -27,8 +28,9 @@ template <typename T, class Context>
class ThresholdedReluGradientOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ThresholdedReluGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {
template <class... Args>
explicit ThresholdedReluGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {
alpha_ = this->template GetSingleArgument<T>("alpha", 1.0);
}

View File

@ -14,8 +14,9 @@ template <class Context>
class TileOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
TileOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit TileOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
tiles_(this->template GetSingleArgument<int32_t>("tiles", 1)),
axis_(this->template GetSingleArgument<int32_t>("axis", 0)) {}
~TileOp() {}
@ -129,8 +130,9 @@ template <typename T, class Context>
class TileGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
TileGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit TileGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
tiles_(this->template GetSingleArgument<int32_t>("tiles", 1)),
axis_(this->template GetSingleArgument<int32_t>("axis", 0)) {}
~TileGradientOp() {}

View File

@ -12,8 +12,9 @@ class TopKOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
TopKOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit TopKOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(int, "k", k_, -1),
OP_SINGLE_ARG(int, "axis", axis_, -1) {
CAFFE_ENFORCE(k_ >= 1, "k argument must be >= 1");
@ -33,8 +34,9 @@ class TopKGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
TopKGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit TopKGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(int, "axis", axis_, -1) {}
~TopKGradientOp() {}

View File

@ -16,8 +16,9 @@ class TransposeOp final : public Operator<Context> {
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_DISPATCH_HELPER;
TransposeOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit TransposeOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
axes_(this->template GetRepeatedArgument<int>("axes")) {
// We will check the legality of axes_: it should be from 0 to axes_.size().
std::vector<int> axes_sorted = axes_;

View File

@ -17,8 +17,9 @@ class CuDNNTransposeOp final : public Operator<CUDAContext> {
USE_OPERATOR_FUNCTIONS(CUDAContext);
USE_DISPATCH_HELPER;
CuDNNTransposeOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CUDAContext>(operator_def, ws),
template <class... Args>
explicit CuDNNTransposeOp(Args&&... args)
: Operator<CUDAContext>(std::forward<Args>(args)...),
cudnn_wrapper_(&context_),
axes_(OperatorBase::GetRepeatedArgument<int>("axes")) {
// We will check the legality of axes_: it should be from 0 to axes_.size().

View File

@ -18,8 +18,9 @@ template <typename T, class Context, class Engine = DefaultEngine>
class TTLinearOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
TTLinearOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit TTLinearOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
inp_sizes_(this->template GetRepeatedArgument<int>("inp_sizes")),
out_sizes_(this->template GetRepeatedArgument<int>("out_sizes")),
tt_ranks_(this->template GetRepeatedArgument<int>("tt_ranks")),
@ -176,8 +177,9 @@ template <typename T, class Context, class Engine = DefaultEngine>
class TTLinearGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
TTLinearGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit TTLinearGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
~TTLinearGradientOp() {}
bool RunOnDevice() override {

View File

@ -24,8 +24,11 @@ namespace caffe2 {
template <typename T, class Context>
class UpsampleBilinearOp final : public Operator<Context> {
public:
UpsampleBilinearOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws), width_scale_(1), height_scale_(1) {
template <class... Args>
explicit UpsampleBilinearOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
width_scale_(1),
height_scale_(1) {
if (HasArgument("width_scale")) {
width_scale_ = static_cast<T>(
this->template GetSingleArgument<float>("width_scale", 1));
@ -49,8 +52,11 @@ class UpsampleBilinearOp final : public Operator<Context> {
template <typename T, class Context>
class UpsampleBilinearGradientOp final : public Operator<Context> {
public:
UpsampleBilinearGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws), width_scale_(1), height_scale_(1) {
template <class... Args>
explicit UpsampleBilinearGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
width_scale_(1),
height_scale_(1) {
width_scale_ = static_cast<T>(
this->template GetSingleArgument<float>("width_scale", 1));
height_scale_ = static_cast<T>(