mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This does 6 things: - add c10/util/Registry.h as the unified registry util - cleaned up some APIs such as export condition - fully remove aten/core/registry.h - fully remove caffe2/core/registry.h - remove a bogus aten/registry.h - unifying all macros - set up registry testing in c10 Also, an important note that we used to mark the templated Registry class as EXPORT - this should not happen, because one should almost never export a template class. This PR fixes that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/12077 Reviewed By: ezyang Differential Revision: D10050771 Pulled By: Yangqing fbshipit-source-id: 417b249b49fed6a67956e7c6b6d22374bcee24cf
76 lines
2.6 KiB
C++
76 lines
2.6 KiB
C++
#ifndef CAFFE2_OPT_OPT_PASSS_H
|
|
#define CAFFE2_OPT_OPT_PASSS_H
|
|
|
|
#include "caffe2/core/common.h"
|
|
#include "caffe2/core/workspace.h"
|
|
#include "caffe2/proto/caffe2_pb.h"
|
|
|
|
#include "nomnigraph/Representations/NeuralNet.h"
|
|
|
|
using namespace nom::repr;
|
|
|
|
namespace caffe2 {
|
|
|
|
/* This file sets up the optimization pass registry.
 *
 * To add a pass, either create a class that inherits from OptimizationPass,
 * implements run(), and is registered with REGISTER_OPT_PASS(clsname), or use
 * the REGISTER_OPT_PASS_FROM_FUNC(name, func) macro to register a function
 * that takes in an NNModule*.
 *
 * If the optimization needs access to the workspace, inherit from
 * WorkspaceOptimizationPass instead and register the pass in the separate
 * WorkspaceOptimizationPassRegistry (via REGISTER_WS_OPT_PASS or
 * REGISTER_WS_OPT_PASS_FROM_FUNC).
 */
|
|
|
|
class CAFFE2_API OptimizationPass {
|
|
public:
|
|
OptimizationPass(NNModule* nn) : nn_(nn) {}
|
|
virtual void run() = 0;
|
|
virtual ~OptimizationPass() {}
|
|
|
|
protected:
|
|
NNModule* nn_;
|
|
};
|
|
|
|
// Optimization pass that receives a Workspace* in addition to the NNModule*.
//
// Both pointers are stored exactly as passed in. run() is not overridden
// here, so it remains pure virtual and concrete passes must implement it.
class CAFFE2_API WorkspaceOptimizationPass : public OptimizationPass {
 public:
  WorkspaceOptimizationPass(NNModule* nn, Workspace* ws)
      : OptimizationPass(nn), ws_(ws) {}

  ~WorkspaceOptimizationPass() override = default;

 protected:
  // Workspace made available to the pass (stored pointer, not deleted here).
  Workspace* ws_;
};
|
|
|
|
// Declares WorkspaceOptimizationPassRegistry, the registry of
// WorkspaceOptimizationPass implementations. The trailing types
// (NNModule*, Workspace*) match the WorkspaceOptimizationPass constructor.
C10_DECLARE_REGISTRY(
    WorkspaceOptimizationPassRegistry, WorkspaceOptimizationPass,
    NNModule*, Workspace*);
|
|
// Registers the class `clsname` in WorkspaceOptimizationPassRegistry.
// `clsname` is forwarded twice to C10_REGISTER_CLASS, so the pass is
// registered under its own class name.
#define REGISTER_WS_OPT_PASS(clsname) \
  C10_REGISTER_CLASS(WorkspaceOptimizationPassRegistry, clsname, clsname)
|
|
// Turns a free function into a registered workspace pass.
//
// Defines a class named `passname` that derives from
// WorkspaceOptimizationPass, inherits its constructors, and implements run()
// by calling funcname(nn_, ws_). The class is then registered through
// REGISTER_WS_OPT_PASS(passname).
#define REGISTER_WS_OPT_PASS_FROM_FUNC(passname, funcname)      \
  class passname : public WorkspaceOptimizationPass {           \
   public:                                                      \
    using WorkspaceOptimizationPass::WorkspaceOptimizationPass; \
    void run() override { funcname(nn_, ws_); }                 \
  };                                                            \
  REGISTER_WS_OPT_PASS(passname);
|
|
|
|
// Declares OptimizationPassRegistry, the registry of OptimizationPass
// implementations. The trailing type (NNModule*) matches the
// OptimizationPass constructor.
C10_DECLARE_REGISTRY(
    OptimizationPassRegistry,
    OptimizationPass,
    NNModule*);
|
|
// Registers the class `clsname` in OptimizationPassRegistry.
// `clsname` is forwarded twice to C10_REGISTER_CLASS, so the pass is
// registered under its own class name.
#define REGISTER_OPT_PASS(clsname) \
  C10_REGISTER_CLASS(OptimizationPassRegistry, clsname, clsname)
|
|
// Turns a free function into a registered optimization pass.
//
// Defines a class named `passname` that derives from OptimizationPass,
// inherits its constructors, and implements run() by calling funcname(nn_).
// The class is then registered through REGISTER_OPT_PASS(passname).
#define REGISTER_OPT_PASS_FROM_FUNC(passname, funcname) \
  class passname : public OptimizationPass {            \
   public:                                              \
    using OptimizationPass::OptimizationPass;           \
    void run() override { funcname(nn_); }              \
  };                                                    \
  REGISTER_OPT_PASS(passname);
|
|
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_OPT_OPT_PASSS_H
|