From 03c5598ba2586379c2293031b1754d437f5cd645 Mon Sep 17 00:00:00 2001 From: Mikhail Zolotukhin Date: Sat, 9 Oct 2021 14:05:57 -0700 Subject: [PATCH] Add python bindings for JIT trace. Is it a hack? No, we should provide python bindings for that. [ghstack-poisoned] --- torch/csrc/jit/python/init.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index dd263ccbd00..b13482c6f90 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -91,6 +91,7 @@ #include #include #include +#include #include #include #include @@ -508,6 +509,22 @@ void initJITBindings(PyObject* module) { }, py::doc( "Interpret a JIT graph with given inputs without running any optimization passes on it")) + .def( + "_jit_trace_graph", + [](std::shared_ptr& graph, const py::tuple& inputs) { + Stack stack; + stack.reserve(inputs.size()); // captures? + for (auto& obj : inputs) { + stack.push_back(toTypeInferredIValue(obj)); + } + auto g_inputs = graph->inputs(); + for (const auto i : c10::irange(inputs.size())) { + if (stack[i].isTensor()) { + g_inputs[i]->setType(stack[i].type()); + } + } + return TraceGraph(graph, stack); + }) .def("_jit_pass_remove_expands", RemoveExpands) .def("_jit_pass_erase_number_types", EraseNumberTypes) .def("_jit_pass_inline_fork_wait", InlineForkWait)