mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10946 ``` codemod -d . --extensions cc,cpp,cu,cuh,h caffe2/proto/caffe2.pb.h caffe2/proto/caffe2_pb.h ``` Reviewed By: houseroad Differential Revision: D9539945 fbshipit-source-id: 497d04720e8e7e61c05ffe1b23733d0cb774de7e
38 lines
1022 B
C++
38 lines
1022 B
C++
#pragma once
|
|
|
|
#include "caffe2/core/common.h"
|
|
#include "caffe2/core/transform.h"
|
|
#include "caffe2/proto/caffe2_pb.h"
|
|
#include "caffe2/utils/proto_utils.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
/**
|
|
* Single Op Transform Base class
|
|
*
|
|
* A transform which is applied to a single node, in place.
|
|
*
|
|
* Transforms which derive from SingleOpTransform need to override:
|
|
* ReplaceOperator and MatchOperator.
|
|
*/
|
|
class CAFFE2_API SingleOpTransform : public Transform {
|
|
protected:
|
|
bool PatternRule(
|
|
const transform::Graph& g,
|
|
const std::vector<int>& subgraph,
|
|
int idx) override;
|
|
bool ValidatorRule(
|
|
const transform::Graph& g,
|
|
const std::vector<int>& subgraph) override;
|
|
bool ReplaceRule(const std::vector<int>& subgraph, transform::Graph* g_ptr)
|
|
override;
|
|
|
|
// Specify what the op needs to be to match the pattern.
|
|
virtual bool MatchOperator(const OperatorDef& op) = 0;
|
|
|
|
// Specify how the operator should be replaced.
|
|
virtual void ReplaceOperator(OperatorDef* op) = 0;
|
|
};
|
|
|
|
} // namespace caffe2
|