mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Add Python bindings for JIT trace.
This is not a hack — we should provide Python bindings for this functionality. [ghstack-poisoned]
This commit is contained in:
parent
b2f0fc9656
commit
03c5598ba2
|
|
@ -91,6 +91,7 @@
|
|||
#include <stdexcept>

#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/csrc/jit/runtime/static/init.h>
|
||||
|
|
@ -508,6 +509,22 @@ void initJITBindings(PyObject* module) {
|
|||
},
|
||||
py::doc(
|
||||
"Interpret a JIT graph with given inputs without running any optimization passes on it"))
|
||||
.def(
    "_jit_trace_graph",
    // Traces `graph` against the example `inputs` and returns the traced
    // graph produced by TraceGraph.
    [](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
        // Materialize the Python arguments as a Stack of IValues, letting
        // toTypeInferredIValue infer each value's JIT type.
        Stack stack;
        stack.reserve(inputs.size());
        for (auto& obj : inputs) {
            stack.push_back(toTypeInferredIValue(obj));
        }
        auto g_inputs = graph->inputs();
        // Guard against out-of-bounds indexing into g_inputs below when the
        // caller supplies more example inputs than the graph declares.
        // (pybind11 translates std::runtime_error to Python RuntimeError.)
        if (inputs.size() > g_inputs.size()) {
            throw std::runtime_error(
                "_jit_trace_graph: received more inputs than the graph declares");
        }
        // Refine the graph's input types with the concrete types observed
        // in the example tensor inputs before tracing.
        for (const auto i : c10::irange(inputs.size())) {
            if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
            }
        }
        return TraceGraph(graph, stack);
    })
|
||||
// Each entry binds a Python-visible name to a C++ callback — by naming
// convention these are JIT graph passes (definitions live elsewhere in the
// project; not visible in this chunk).
.def("_jit_pass_remove_expands", RemoveExpands)
.def("_jit_pass_erase_number_types", EraseNumberTypes)
.def("_jit_pass_inline_fork_wait", InlineForkWait)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user