#include "torch/csrc/utils/pybind.h"

#include "torch/csrc/jit/python_tracer.h"
#include "torch/csrc/jit/python_ir.h"
#include "torch/csrc/jit/python_arg_flatten.h"
#include "torch/csrc/jit/export.h"
#include "torch/csrc/jit/python_compiled_function.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/onnx.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/passes/canonicalize.h"
#include "torch/csrc/jit/passes/onnx/peephole.h"

#include <memory>
#include <string>
#include <utility>

namespace torch { namespace jit {

namespace {

// Exposed to Python as `_jit_init`. It currently does nothing and always
// reports success. The commented-out code is kept as a starting point for
// looking up classes in the `torch.jit` Python module.
bool loadPythonClasses() {
  // Leaving this code here, because it will likely be useful at some point
  //PyObject *jit_module = PyImport_ImportModule("torch.jit");
  //THPUtils_assert(jit_module, "class loader couldn't access "
  //                "torch.jit module");
  //PyObject *jit_dict = PyModule_GetDict(jit_module);

  return true;
}

// Adapts a graph pass `F` into a function that takes a tracing state, so the
// pass can be bound to Python directly. `F` receives the state's graph handle
// by non-const reference, so a pass is able to replace the graph it is given.
//
// Template parameters:
//   F - the pass to run on `state->graph`.
// Parameters:
//   state - tracing state whose `graph` member the pass operates on.
template<void (*F)(std::shared_ptr<Graph>& graph)>
void graph_pass(const std::shared_ptr<tracer::TracingState>& state) {
  F(state->graph);
}

} // anonymous namespace

// Defined in another translation unit; exposed to Python as
// `_jit_run_cpp_tests`.
extern std::string runJITCPPTests();

// Registers the JIT bindings on the Python module `module`:
//  - the opaque `IODescriptor` class,
//  - the graph passes (`_jit_pass_*`),
//  - the C++ test runner,
//  - the argument flatten/unflatten helpers.
// It then delegates to the IR, tracer and compiler-mixin binding initializers.
void initJITBindings(PyObject *module) {
  auto m = py::handle(module).cast<py::module>();

  // Registered without methods: Python only needs to hold and pass back
  // descriptors produced by `_jit_flatten`.
  py::class_<python::IODescriptor>(m, "IODescriptor");

  m.def("_jit_init", loadPythonClasses)
   .def("_jit_pass_onnx", ToONNX)
   .def("_jit_pass_onnx_peephole", graph_pass<PeepholeOptimizeONNX>)
   .def("_jit_pass_fuse", graph_pass<FuseGraph>)
   .def("_jit_pass_dce", graph_pass<EliminateDeadCode>)
   .def("_jit_pass_cse", graph_pass<EliminateCommonSubexpression>)
   .def("_jit_pass_peephole", graph_pass<PeepholeOptimize>)
   .def("_jit_pass_canonicalize", graph_pass<Canonicalize>)
   .def("_jit_pass_lint", graph_pass<LintGraph>)
   .def("_jit_run_cpp_tests", runJITCPPTests)
   // Returns a (vars, desc) pair built from python::flatten's result.
   .def("_jit_flatten", [](py::handle obj) {
     auto res = python::flatten(obj);
     return std::make_pair(res.vars, res.desc);
   })
   // reinterpret_steal takes ownership of the pointer returned by
   // python::unflatten, so that call is presumably expected to return a new
   // reference (verify against python_arg_flatten).
   .def("_jit_unflatten", [](autograd::variable_list vars, python::IODescriptor& desc) {
     return py::reinterpret_steal<py::object>(python::unflatten(vars, desc));
   });

  initPythonIRBindings(module);
  initPythonTracerBindings(module);
  python::initCompilerMixin(module);
}

}}