pytorch/torch/csrc/jit/init.cpp
Zachary DeVito 0ae5498079 [JIT] add create_autodiff_subgraphs (#4822)
This pass splits differentiable subgraphs into their own Node,
similar to a fusion group.

This initial implementation does not create optimal subgraphs, but
it works well in the case where most things are differentiable,
and has the building blocks (`mergeNodes`) to extend to the
better implementation.
2018-01-23 23:46:54 -05:00

70 lines
2.3 KiB
C++

#include "torch/csrc/utils/pybind.h"
#include "torch/csrc/jit/python_tracer.h"
#include "torch/csrc/jit/python_ir.h"
#include "torch/csrc/jit/python_arg_flatten.h"
#include "torch/csrc/jit/export.h"
#include "torch/csrc/jit/python_compiled_function.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/onnx.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/passes/canonicalize.h"
#include "torch/csrc/jit/passes/onnx/peephole.h"
namespace torch { namespace jit {
namespace {
// Initialization hook exposed to Python as `_jit_init`. There is currently
// nothing to load, so it unconditionally reports success.
bool loadPythonClasses() {
  // The sketch below is retained deliberately; it is a likely starting point
  // if this hook ever needs to reach into the torch.jit Python module:
  //   PyObject *jit_module = PyImport_ImportModule("torch.jit");
  //   THPUtils_assert(jit_module, "class loader couldn't access "
  //                   "torch.jit module");
  //   PyObject *jit_dict = PyModule_GetDict(jit_module);
  constexpr bool kLoaded = true;
  return kLoaded;
}
// Adapter that lets a graph-level pass be called with a tracing state.
// `Pass` is a function taking `std::shared_ptr<Graph>&`; it is applied to the
// `graph` member of the supplied state. Nothing is returned.
template<void (*Pass)(std::shared_ptr<Graph>&)>
void graph_pass(const std::shared_ptr<tracer::TracingState>& state) {
  Pass(state->graph);
}
} // anonymous namespace
// Declaration only: no definition is visible in this file. The function takes
// no arguments, returns a std::string, and is exposed to Python below as
// `_jit_run_cpp_tests`.
extern std::string runJITCPPTests();
// Registers the JIT's Python-facing entry points on `module`.
//
// The raw PyObject* is viewed as a pybind11 module, the IODescriptor class is
// registered on it, and then each `_jit_*` function is bound one at a time in
// a fixed order. Finally the IR, tracer and compiler-mixin initializers are
// run against the same module.
void initJITBindings(PyObject *module) {
  py::module m = py::handle(module).cast<py::module>();

  // IODescriptor is registered with no members; it appears in the
  // `_jit_flatten` / `_jit_unflatten` signatures below.
  py::class_<python::IODescriptor>(m, "IODescriptor");

  m.def("_jit_init", loadPythonClasses);

  // Graph passes. All but `_jit_pass_onnx` go through graph_pass<>, which
  // applies the pass to the graph held by a tracing state.
  m.def("_jit_pass_onnx", ToONNX);
  m.def("_jit_pass_onnx_peephole", graph_pass<PeepholeOptimizeONNX>);
  m.def("_jit_pass_fuse", graph_pass<FuseGraph>);
  m.def("_jit_pass_dce", graph_pass<EliminateDeadCode>);
  m.def("_jit_pass_cse", graph_pass<EliminateCommonSubexpression>);
  m.def("_jit_pass_peephole", graph_pass<PeepholeOptimize>);
  m.def("_jit_pass_canonicalize", graph_pass<Canonicalize>);
  m.def("_jit_pass_lint", graph_pass<LintGraph>);

  m.def("_jit_run_cpp_tests", runJITCPPTests);

  // Returns the `vars` and `desc` members of python::flatten's result as a
  // pair.
  m.def("_jit_flatten", [](py::handle& obj) {
    auto flattened = python::flatten(obj);
    return std::make_pair(flattened.vars, flattened.desc);
  });

  // reinterpret_steal wraps the pointer returned by python::unflatten without
  // adding a reference of its own.
  m.def("_jit_unflatten", [](autograd::variable_list vars, python::IODescriptor& desc) {
    return py::reinterpret_steal<py::object>(python::unflatten(vars, desc));
  });

  initPythonIRBindings(module);
  initPythonTracerBindings(module);
  python::initCompilerMixin(module);
}
}}