Add python bindings for JIT trace.

Is it a hack?
No, we should provide python bindings for that.

[ghstack-poisoned]
This commit is contained in:
Mikhail Zolotukhin 2021-10-09 14:05:57 -07:00
parent b2f0fc9656
commit 03c5598ba2

View File

@ -91,6 +91,7 @@
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/csrc/jit/runtime/static/init.h>
@ -508,6 +509,22 @@ void initJITBindings(PyObject* module) {
},
py::doc(
"Interpret a JIT graph with given inputs without running any optimization passes on it"))
.def(
"_jit_trace_graph",
[](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
Stack stack;
stack.reserve(inputs.size()); // captures?
for (auto& obj : inputs) {
stack.push_back(toTypeInferredIValue(obj));
}
auto g_inputs = graph->inputs();
for (const auto i : c10::irange(inputs.size())) {
if (stack[i].isTensor()) {
g_inputs[i]->setType(stack[i].type());
}
}
return TraceGraph(graph, stack);
})
.def("_jit_pass_remove_expands", RemoveExpands)
.def("_jit_pass_erase_number_types", EraseNumberTypes)
.def("_jit_pass_inline_fork_wait", InlineForkWait)