pytorch/caffe2/operators/batch_box_cox_op.h
efiks 2e4c89eba9 [torch] Unify batch_box_cox implementations into perfkernels folder (#86569)
Summary:
1) Adding an MKL/AVX2-based implementation to perfkernels. This implementation is similar to caffe2/operators/batch_box_cox_op.cc
2) Migrating caffe2's batch_box_cox_op to use this implementation

Test Plan: CI

Differential Revision: D40208074

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86569
Approved by: https://github.com/hyuen
2022-10-23 19:29:25 +00:00

#ifndef CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
#define CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_

#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(BatchBoxCox);

namespace caffe2 {

template <class Context>
class BatchBoxCoxOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit BatchBoxCoxOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        min_block_size_(
            this->template GetSingleArgument<int>("min_block_size", 256)) {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, double>>::call(this, Input(DATA));
  }

  template <typename T>
  bool DoRunWithType();

 protected:
  // Value of the "min_block_size" operator argument (default 256); a tuning
  // knob for the underlying Box-Cox kernel.
  std::size_t min_block_size_;

  // Inputs: the data tensor plus the per-column lambda1/lambda2 parameters.
  INPUT_TAGS(DATA, LAMBDA1, LAMBDA2);
};

} // namespace caffe2

#endif // CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
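
For illustration only, below is a minimal, self-contained sketch of the transform this operator family applies: the standard two-parameter Box-Cox transform, evaluated column-wise over a row-major N x D batch with one (lambda1, lambda2) pair per column. The function names and layout here are assumptions made for the sketch; they are not the caffe2 or perfkernels API, and the sketch ignores the MKL/AVX2 fast path and the min_block_size tuning mentioned above.

#include <cmath>
#include <cstddef>

// Sketch only: two-parameter Box-Cox transform of a single value.
// lambda1 is the power parameter; lambda2 shifts the input before the
// power/log is taken. The lambda1 == 0 branch is the log-transform limit.
template <typename T>
T box_cox_value(T x, T lambda1, T lambda2) {
  const T shifted = x + lambda2;
  return lambda1 == T(0)
      ? std::log(shifted)
      : (std::pow(shifted, lambda1) - T(1)) / lambda1;
}

// Sketch only: apply the transform to a row-major N x D batch, using the
// j-th (lambda1, lambda2) pair for every element of column j.
template <typename T>
void batch_box_cox_reference(
    std::size_t N,
    std::size_t D,
    const T* data,
    const T* lambda1,
    const T* lambda2,
    T* output) {
  for (std::size_t i = 0; i < N; ++i) {
    for (std::size_t j = 0; j < D; ++j) {
      output[i * D + j] =
          box_cox_value(data[i * D + j], lambda1[j], lambda2[j]);
    }
  }
}

Instantiating the sketch with T = float or T = double mirrors the two types accepted by the TensorTypes<float, double> dispatch in RunOnDevice above.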