pytorch/caffe2/opt/passes.h
Yangqing Jia 9c49bb9ddf Move registry fully to c10 (#12077)
Summary:
This does 6 things:

- add c10/util/Registry.h as the unified registry util
  - clean up some APIs, such as the export condition
- fully remove aten/core/registry.h
- fully remove caffe2/core/registry.h
- remove a bogus aten/registry.h
- unify all macros
- set up registry testing in c10

Also, an important note: we used to mark the templated Registry class as EXPORT. This should not happen, because one should almost never export a template class. This PR fixes that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12077

Reviewed By: ezyang

Differential Revision: D10050771

Pulled By: Yangqing

fbshipit-source-id: 417b249b49fed6a67956e7c6b6d22374bcee24cf
2018-09-27 03:09:54 -07:00

76 lines
2.6 KiB
C++

#ifndef CAFFE2_OPT_OPT_PASSS_H
#define CAFFE2_OPT_OPT_PASSS_H
#include "caffe2/core/common.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
#include "nomnigraph/Representations/NeuralNet.h"
using namespace nom::repr;
namespace caffe2 {
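/* This file sets up the optimization pass registries.
*
* To add a pass, either create a class that inherits from OptimizationPass,
* implements run(), and is registered with REGISTER_OPT_PASS(clsname), or use
* REGISTER_OPT_PASS_FROM_FUNC(name, func) to register a function that takes
* an NNModule*.
*
* If the optimization needs access to the workspace, use the separate
* workspace registry instead: inherit from WorkspaceOptimizationPass and
* register with REGISTER_WS_OPT_PASS(clsname), or use
* REGISTER_WS_OPT_PASS_FROM_FUNC(name, func) with a function that takes an
* NNModule* and a Workspace*.
*/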
/// Abstract base class for optimization passes over an NNModule.
///
/// Subclasses implement run(), which operates on the module supplied at
/// construction time (available to subclasses as nn_). The pass does not
/// take ownership of the module: nothing here ever deletes nn_.
class CAFFE2_API OptimizationPass {
 public:
  /// @param nn  Module the pass operates on; not owned by the pass.
  /// explicit: a raw NNModule* should never silently convert into a pass.
  explicit OptimizationPass(NNModule* nn) : nn_(nn) {}

  /// Executes the pass against nn_.
  virtual void run() = 0;

  // Polymorphic base: passes are held and destroyed through base pointers.
  virtual ~OptimizationPass() = default;

 protected:
  NNModule* nn_; // non-owning
};
/// Base class for optimization passes that need a Workspace in addition to
/// the NNModule. Subclasses implement run() and may use both nn_ and ws_.
/// Neither pointer is owned: the destructor releases nothing.
class CAFFE2_API WorkspaceOptimizationPass : public OptimizationPass {
 public:
  /// @param nn  Module the pass operates on; not owned by the pass.
  /// @param ws  Workspace made available to the pass; not owned by the pass.
  WorkspaceOptimizationPass(NNModule* nn, Workspace* ws)
      : OptimizationPass(nn), ws_(ws) {}

  // Already virtual via OptimizationPass; `override` states that explicitly.
  ~WorkspaceOptimizationPass() override = default;

 protected:
  Workspace* ws_; // non-owning
};
/// Registry for WorkspaceOptimizationPass subclasses; entries are created
/// with the creator arguments (NNModule*, Workspace*).
C10_DECLARE_REGISTRY(
    WorkspaceOptimizationPassRegistry,
    WorkspaceOptimizationPass,
    NNModule*,
    Workspace*);

/// Registers a WorkspaceOptimizationPass subclass. The class-name token is
/// passed to C10_REGISTER_CLASS as both the registry key and the type.
#define REGISTER_WS_OPT_PASS(clsname) \
  C10_REGISTER_CLASS(WorkspaceOptimizationPassRegistry, clsname, clsname)

/// Defines a WorkspaceOptimizationPass subclass named `passname` whose run()
/// calls `funcname(nn_, ws_)`, i.e. funcname must be callable with
/// (NNModule*, Workspace*). The generated class is then registered via
/// REGISTER_WS_OPT_PASS.
#define REGISTER_WS_OPT_PASS_FROM_FUNC(passname, funcname)      \
  class passname : public WorkspaceOptimizationPass {           \
   public:                                                      \
    using WorkspaceOptimizationPass::WorkspaceOptimizationPass; \
    void run() override {                                       \
      funcname(this->nn_, this->ws_);                           \
    }                                                           \
  };                                                            \
  REGISTER_WS_OPT_PASS(passname);
/// Registry for OptimizationPass subclasses; entries are created with the
/// single creator argument NNModule*.
C10_DECLARE_REGISTRY(
    OptimizationPassRegistry,
    OptimizationPass,
    NNModule*);

/// Registers an OptimizationPass subclass. The class-name token is passed to
/// C10_REGISTER_CLASS as both the registry key and the type.
#define REGISTER_OPT_PASS(clsname) \
  C10_REGISTER_CLASS(OptimizationPassRegistry, clsname, clsname)

/// Defines an OptimizationPass subclass named `passname` whose run() calls
/// `funcname(nn_)`, i.e. funcname must be callable with an NNModule*. The
/// generated class is then registered via REGISTER_OPT_PASS.
#define REGISTER_OPT_PASS_FROM_FUNC(passname, funcname) \
  class passname : public OptimizationPass {            \
   public:                                              \
    using OptimizationPass::OptimizationPass;           \
    void run() override {                               \
      funcname(this->nn_);                              \
    }                                                   \
  };                                                    \
  REGISTER_OPT_PASS(passname);
} // namespace caffe2
#endif // CAFFE2_OPT_OPT_PASSS_H