mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
refactor caffe2 operator constructors - 8/9 (#17089)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17089 clangr codemod Reviewed By: ezyang Differential Revision: D14078539 fbshipit-source-id: 9ca196af4af7f26fc82e6cf82b35d478d0597752
This commit is contained in:
parent
28b5df1c8f
commit
7413f0926a
|
|
@ -202,8 +202,9 @@ template <class Context>
|
|||
class SliceOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SliceOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SliceOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
starts_(this->template GetRepeatedArgument<int64_t>("starts")),
|
||||
ends_(this->template GetRepeatedArgument<int64_t>("ends")),
|
||||
statically_inited_(false) {}
|
||||
|
|
@ -263,8 +264,9 @@ template <class Context>
|
|||
class SliceGradientOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SliceGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SliceGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
starts_(this->template GetRepeatedArgument<int64_t>("starts")),
|
||||
ends_(this->template GetRepeatedArgument<int64_t>("ends")),
|
||||
statically_inited_(false) {}
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ namespace caffe2 {
|
|||
template <typename T, class Context>
|
||||
class SoftmaxOp final : public Operator<Context> {
|
||||
public:
|
||||
SoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SoftmaxOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
axis_(this->template GetSingleArgument<int>("axis", 1)) {}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
bool RunOnDevice() override;
|
||||
|
|
@ -27,8 +28,9 @@ class SoftmaxOp final : public Operator<Context> {
|
|||
template <typename T, class Context>
|
||||
class SoftmaxGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
SoftmaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SoftmaxGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
axis_(this->template GetSingleArgument<int>("axis", 1)) {}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
bool RunOnDevice() override;
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ constexpr int TOP_GRADIENT_DESC_ID = 2;
|
|||
|
||||
class CuDNNSoftmaxOp final : public Operator<CUDAContext> {
|
||||
public:
|
||||
explicit CuDNNSoftmaxOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<CUDAContext>(def, ws),
|
||||
template <class... Args>
|
||||
explicit CuDNNSoftmaxOp(Args&&... args)
|
||||
: Operator<CUDAContext>(std::forward<Args>(args)...),
|
||||
cudnn_wrapper_(&context_),
|
||||
axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
|
||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
|
||||
|
|
@ -77,8 +78,9 @@ class CuDNNSoftmaxOp final : public Operator<CUDAContext> {
|
|||
|
||||
class CuDNNSoftmaxGradientOp final : public Operator<CUDAContext> {
|
||||
public:
|
||||
explicit CuDNNSoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<CUDAContext>(def, ws),
|
||||
template <class... Args>
|
||||
explicit CuDNNSoftmaxGradientOp(Args&&... args)
|
||||
: Operator<CUDAContext>(std::forward<Args>(args)...),
|
||||
cudnn_wrapper_(&context_),
|
||||
axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
|
||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
|
||||
|
|
|
|||
|
|
@ -11,10 +11,12 @@ namespace caffe2 {
|
|||
template <typename T, class Context>
|
||||
class SoftmaxWithLossOp final : public Operator<Context> {
|
||||
public:
|
||||
SoftmaxWithLossOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SoftmaxWithLossOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
scale_(this->template GetSingleArgument<float>("scale", 1.)),
|
||||
label_prob_mode_(this->template GetSingleArgument<int>("label_prob", 0)),
|
||||
label_prob_mode_(
|
||||
this->template GetSingleArgument<int>("label_prob", 0)),
|
||||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<string>("order", "NCHW"))),
|
||||
axis_(this->template GetSingleArgument<int>("axis", 1)) {
|
||||
|
|
@ -44,10 +46,12 @@ class SoftmaxWithLossOp final : public Operator<Context> {
|
|||
template <typename T, class Context>
|
||||
class SoftmaxWithLossGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
SoftmaxWithLossGradientOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<Context>(def, ws),
|
||||
template <class... Args>
|
||||
explicit SoftmaxWithLossGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
scale_(this->template GetSingleArgument<float>("scale", 1.)),
|
||||
label_prob_mode_(this->template GetSingleArgument<int>("label_prob", 0)),
|
||||
label_prob_mode_(
|
||||
this->template GetSingleArgument<int>("label_prob", 0)),
|
||||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<string>("order", "NCHW"))),
|
||||
only_loss_(this->template GetSingleArgument<bool>("only_loss", false)),
|
||||
|
|
|
|||
|
|
@ -111,8 +111,9 @@ template <typename Context>
|
|||
class SpaceBatchOpBase : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SpaceBatchOpBase(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SpaceBatchOpBase(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
pad_(this->template GetSingleArgument<int>("pad", 0)),
|
||||
pad_t_(this->template GetSingleArgument<int>("pad_t", pad_)),
|
||||
pad_l_(this->template GetSingleArgument<int>("pad", pad_)),
|
||||
|
|
|
|||
|
|
@ -9,8 +9,9 @@ template <typename T, class Context>
|
|||
class CAFFE2_API SparseNormalizeOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SparseNormalizeOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SparseNormalizeOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
use_max_norm_(
|
||||
this->template GetSingleArgument<bool>("use_max_norm", true)),
|
||||
norm_(this->template GetSingleArgument<float>("norm", 1.0)) {
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ template <class Context>
|
|||
class SparseToDenseMaskBase : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SparseToDenseMaskBase(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {
|
||||
template <class... Args>
|
||||
explicit SparseToDenseMaskBase(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {
|
||||
std::vector<int64_t> mask =
|
||||
this->template GetRepeatedArgument<int64_t>("mask");
|
||||
featuresCount_ = mask.size();
|
||||
|
|
@ -62,8 +63,9 @@ template <class Context>
|
|||
class SparseToDenseMaskOp : public SparseToDenseMaskBase<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SparseToDenseMaskOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: SparseToDenseMaskBase<Context>(operator_def, ws) {
|
||||
template <class... Args>
|
||||
explicit SparseToDenseMaskOp(Args&&... args)
|
||||
: SparseToDenseMaskBase<Context>(std::forward<Args>(args)...) {
|
||||
returnPresenceMask_ = this->template GetSingleArgument<bool>(
|
||||
"return_presence_mask", false);
|
||||
maxSkippedSparseIndices_ =
|
||||
|
|
@ -192,8 +194,9 @@ template <class Context>
|
|||
class SparseToDenseMaskGradientOp : public SparseToDenseMaskBase<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SparseToDenseMaskGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: SparseToDenseMaskBase<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit SparseToDenseMaskGradientOp(Args&&... args)
|
||||
: SparseToDenseMaskBase<Context>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ class SparseToDenseOp final : public Operator<Context> {
|
|||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
USE_DISPATCH_HELPER;
|
||||
|
||||
SparseToDenseOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SparseToDenseOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
output_first_dim_(
|
||||
this->template GetSingleArgument<int>("output_first_dim", 0)) {}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,8 +19,9 @@ class SpatialBNOp : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
SpatialBNOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SpatialBNOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
OP_SINGLE_ARG(bool, OpSchema::Arg_IsTest, is_test_, false),
|
||||
OP_SINGLE_ARG(double, "epsilon", epsilon_, 1e-5),
|
||||
OP_SINGLE_ARG(float, "momentum", momentum_, 0.9f),
|
||||
|
|
@ -281,8 +282,9 @@ class SpatialBNGradientOp : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
SpatialBNGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SpatialBNGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
OP_SINGLE_ARG(double, "epsilon", epsilon_, 1e-5),
|
||||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<string>("order", "NCHW"))),
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ namespace caffe2 {
|
|||
template <typename T, class Context>
|
||||
class SpatialSoftmaxWithLossOp final : public Operator<Context> {
|
||||
public:
|
||||
SpatialSoftmaxWithLossOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SpatialSoftmaxWithLossOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
scale_(this->template GetSingleArgument<float>("scale", 1.)),
|
||||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<string>("order", "NCHW"))) {
|
||||
|
|
@ -39,8 +40,9 @@ class SpatialSoftmaxWithLossOp final : public Operator<Context> {
|
|||
template <typename T, class Context>
|
||||
class SpatialSoftmaxWithLossGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
SpatialSoftmaxWithLossGradientOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<Context>(def, ws),
|
||||
template <class... Args>
|
||||
explicit SpatialSoftmaxWithLossGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
scale_(this->template GetSingleArgument<float>("scale", 1.)),
|
||||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<string>("order", "NCHW"))),
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ class SquareRootDivideOp final : public Operator<Context> {
|
|||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
USE_DISPATCH_HELPER;
|
||||
|
||||
SquareRootDivideOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit SquareRootDivideOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
return DispatchHelper<TensorTypes<float>>::call(this, Input(DATA));
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ namespace caffe2 {
|
|||
|
||||
class StatRegistryCreateOp : public Operator<CPUContext> {
|
||||
public:
|
||||
StatRegistryCreateOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit StatRegistryCreateOp(Args&&... args)
|
||||
: Operator(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
*OperatorBase::Output<std::unique_ptr<StatRegistry>>(0) =
|
||||
|
|
@ -20,8 +21,9 @@ class StatRegistryCreateOp : public Operator<CPUContext> {
|
|||
|
||||
class StatRegistryExportOp : public Operator<CPUContext> {
|
||||
public:
|
||||
StatRegistryExportOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit StatRegistryExportOp(Args&&... args)
|
||||
: Operator(std::forward<Args>(args)...),
|
||||
reset_(GetSingleArgument<bool>("reset", true)) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
@ -55,8 +57,9 @@ class StatRegistryExportOp : public Operator<CPUContext> {
|
|||
|
||||
class StatRegistryUpdateOp : public Operator<CPUContext> {
|
||||
public:
|
||||
StatRegistryUpdateOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit StatRegistryUpdateOp(Args&&... args)
|
||||
: Operator(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
const auto& keys = Input(0);
|
||||
|
|
@ -118,7 +121,7 @@ class TimerInstance {
|
|||
};
|
||||
|
||||
struct TimerBeginOp : public Operator<CPUContext> {
|
||||
TimerBeginOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
explicit TimerBeginOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator(operator_def, ws),
|
||||
given_name_(GetSingleArgument<std::string>(
|
||||
"counter_name",
|
||||
|
|
@ -137,8 +140,8 @@ struct TimerBeginOp : public Operator<CPUContext> {
|
|||
};
|
||||
|
||||
struct TimerEndOp : public Operator<CPUContext> {
|
||||
TimerEndOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit TimerEndOp(Args&&... args) : Operator(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
OperatorBase::Input<TimerInstance*>(0)->end();
|
||||
|
|
@ -147,8 +150,9 @@ struct TimerEndOp : public Operator<CPUContext> {
|
|||
};
|
||||
|
||||
struct TimerGetAndEndOp : public Operator<CPUContext> {
|
||||
TimerGetAndEndOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit TimerGetAndEndOp(Args&&... args)
|
||||
: Operator(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
int64_t nanos = OperatorBase::Input<TimerInstance*>(0)->get_ns();
|
||||
|
|
@ -161,8 +165,8 @@ struct TimerGetAndEndOp : public Operator<CPUContext> {
|
|||
};
|
||||
|
||||
struct TimerGetOp : public Operator<CPUContext> {
|
||||
TimerGetOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit TimerGetOp(Args&&... args) : Operator(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
int64_t nanos = OperatorBase::Input<TimerInstance*>(0)->get_ns();
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ namespace caffe2 {
|
|||
|
||||
template <typename T>
|
||||
struct TemplatePutOp : public Operator<CPUContext> {
|
||||
TemplatePutOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
explicit TemplatePutOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator(operator_def, ws),
|
||||
given_name_(GetSingleArgument<std::string>(
|
||||
"stat_name",
|
||||
|
|
|
|||
|
|
@ -41,8 +41,9 @@ class StringJoinOp final : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
StringJoinOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit StringJoinOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
delimiter_(
|
||||
this->template GetSingleArgument<std::string>("delimiter", ",")),
|
||||
axis_(this->template GetSingleArgument<int>("axis", 0)) {
|
||||
|
|
|
|||
|
|
@ -32,8 +32,9 @@ class StumpFuncOp final : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
StumpFuncOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit StumpFuncOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
threshold_(this->template GetSingleArgument<TIN>("threshold", 0)),
|
||||
low_value_(this->template GetSingleArgument<TOUT>("low_value", 0)),
|
||||
high_value_(this->template GetSingleArgument<TOUT>("high_value", 0)) {}
|
||||
|
|
@ -53,8 +54,9 @@ class StumpFuncIndexOp final : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
StumpFuncIndexOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit StumpFuncIndexOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
threshold_(this->template GetSingleArgument<TIN>("threshold", 0)) {}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
|
|
|
|||
|
|
@ -69,9 +69,7 @@ class PackedInt8BGRANHWCToNCHWCStylizerPreprocessOp
|
|||
static constexpr int kNeonNoiseReadSize = kOutputChannels * 16;
|
||||
|
||||
USE_OPERATOR_FUNCTIONS(CPUContext);
|
||||
PackedInt8BGRANHWCToNCHWCStylizerPreprocessOp(
|
||||
const OperatorDef& operator_def,
|
||||
Workspace* ws)
|
||||
explicit PackedInt8BGRANHWCToNCHWCStylizerPreprocessOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws), ws_(ws) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ constexpr char kSummaryzeOpExtension[] = ".summary";
|
|||
template <typename T, class Context>
|
||||
class SummarizeOp final : public Operator<Context> {
|
||||
public:
|
||||
SummarizeOp(const OperatorDef& def, Workspace* ws)
|
||||
explicit SummarizeOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<Context>(def, ws),
|
||||
to_file_(this->template GetSingleArgument<int>("to_file", 0)) {
|
||||
if (to_file_) {
|
||||
|
|
|
|||
|
|
@ -38,8 +38,9 @@ struct TextFileReaderInstance {
|
|||
|
||||
class CreateTextFileReaderOp : public Operator<CPUContext> {
|
||||
public:
|
||||
CreateTextFileReaderOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit CreateTextFileReaderOp(Args&&... args)
|
||||
: Operator<CPUContext>(std::forward<Args>(args)...),
|
||||
filename_(GetSingleArgument<string>("filename", "")),
|
||||
numPasses_(GetSingleArgument<int>("num_passes", 1)),
|
||||
fieldTypes_(GetRepeatedArgument<int>("field_types")) {
|
||||
|
|
@ -86,8 +87,9 @@ inline void convert(
|
|||
|
||||
class TextFileReaderReadOp : public Operator<CPUContext> {
|
||||
public:
|
||||
TextFileReaderReadOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit TextFileReaderReadOp(Args&&... args)
|
||||
: Operator<CPUContext>(std::forward<Args>(args)...),
|
||||
batchSize_(GetSingleArgument<int>("batch_size", 1)) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ template <typename T, class Context>
|
|||
class ThresholdedReluOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
ThresholdedReluOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {
|
||||
template <class... Args>
|
||||
explicit ThresholdedReluOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {
|
||||
alpha_ = this->template GetSingleArgument<T>("alpha", 1.0);
|
||||
}
|
||||
|
||||
|
|
@ -27,8 +28,9 @@ template <typename T, class Context>
|
|||
class ThresholdedReluGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
ThresholdedReluGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {
|
||||
template <class... Args>
|
||||
explicit ThresholdedReluGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {
|
||||
alpha_ = this->template GetSingleArgument<T>("alpha", 1.0);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,8 +14,9 @@ template <class Context>
|
|||
class TileOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
TileOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit TileOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
tiles_(this->template GetSingleArgument<int32_t>("tiles", 1)),
|
||||
axis_(this->template GetSingleArgument<int32_t>("axis", 0)) {}
|
||||
~TileOp() {}
|
||||
|
|
@ -129,8 +130,9 @@ template <typename T, class Context>
|
|||
class TileGradientOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
TileGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit TileGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
tiles_(this->template GetSingleArgument<int32_t>("tiles", 1)),
|
||||
axis_(this->template GetSingleArgument<int32_t>("axis", 0)) {}
|
||||
~TileGradientOp() {}
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ class TopKOp : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
TopKOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit TopKOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
OP_SINGLE_ARG(int, "k", k_, -1),
|
||||
OP_SINGLE_ARG(int, "axis", axis_, -1) {
|
||||
CAFFE_ENFORCE(k_ >= 1, "k argument must be >= 1");
|
||||
|
|
@ -33,8 +34,9 @@ class TopKGradientOp : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
TopKGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit TopKGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
OP_SINGLE_ARG(int, "axis", axis_, -1) {}
|
||||
|
||||
~TopKGradientOp() {}
|
||||
|
|
|
|||
|
|
@ -16,8 +16,9 @@ class TransposeOp final : public Operator<Context> {
|
|||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
USE_DISPATCH_HELPER;
|
||||
|
||||
TransposeOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit TransposeOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
axes_(this->template GetRepeatedArgument<int>("axes")) {
|
||||
// We will check the legality of axes_: it should be from 0 to axes_.size().
|
||||
std::vector<int> axes_sorted = axes_;
|
||||
|
|
|
|||
|
|
@ -17,8 +17,9 @@ class CuDNNTransposeOp final : public Operator<CUDAContext> {
|
|||
USE_OPERATOR_FUNCTIONS(CUDAContext);
|
||||
USE_DISPATCH_HELPER;
|
||||
|
||||
CuDNNTransposeOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CUDAContext>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit CuDNNTransposeOp(Args&&... args)
|
||||
: Operator<CUDAContext>(std::forward<Args>(args)...),
|
||||
cudnn_wrapper_(&context_),
|
||||
axes_(OperatorBase::GetRepeatedArgument<int>("axes")) {
|
||||
// We will check the legality of axes_: it should be from 0 to axes_.size().
|
||||
|
|
|
|||
|
|
@ -18,8 +18,9 @@ template <typename T, class Context, class Engine = DefaultEngine>
|
|||
class TTLinearOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
TTLinearOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit TTLinearOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
inp_sizes_(this->template GetRepeatedArgument<int>("inp_sizes")),
|
||||
out_sizes_(this->template GetRepeatedArgument<int>("out_sizes")),
|
||||
tt_ranks_(this->template GetRepeatedArgument<int>("tt_ranks")),
|
||||
|
|
@ -176,8 +177,9 @@ template <typename T, class Context, class Engine = DefaultEngine>
|
|||
class TTLinearGradientOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
TTLinearGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit TTLinearGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
~TTLinearGradientOp() {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
|
|||
|
|
@ -24,8 +24,11 @@ namespace caffe2 {
|
|||
template <typename T, class Context>
|
||||
class UpsampleBilinearOp final : public Operator<Context> {
|
||||
public:
|
||||
UpsampleBilinearOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws), width_scale_(1), height_scale_(1) {
|
||||
template <class... Args>
|
||||
explicit UpsampleBilinearOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
width_scale_(1),
|
||||
height_scale_(1) {
|
||||
if (HasArgument("width_scale")) {
|
||||
width_scale_ = static_cast<T>(
|
||||
this->template GetSingleArgument<float>("width_scale", 1));
|
||||
|
|
@ -49,8 +52,11 @@ class UpsampleBilinearOp final : public Operator<Context> {
|
|||
template <typename T, class Context>
|
||||
class UpsampleBilinearGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
UpsampleBilinearGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws), width_scale_(1), height_scale_(1) {
|
||||
template <class... Args>
|
||||
explicit UpsampleBilinearGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
width_scale_(1),
|
||||
height_scale_(1) {
|
||||
width_scale_ = static_cast<T>(
|
||||
this->template GetSingleArgument<float>("width_scale", 1));
|
||||
height_scale_ = static_cast<T>(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user