#include #include #include #include #include #include namespace torch::jit { namespace prim { using namespace ::c10::prim; } GraphFunction* tryToGraphFunction(Node* n) { if (n->kind() == prim::CallFunction) { AT_ASSERT(n->input(0)->node()->kind() == prim::Constant); auto function_constant = n->input(0)->node(); auto fun_type = function_constant->output()->type()->expect(); return tryToGraphFunction(*fun_type->function()); } if (n->kind() == prim::CallMethod) { const std::string& name = n->s(attr::name); if (auto class_type = n->input(0)->type()->cast()) { Function& function = class_type->getMethod(name); return tryToGraphFunction(function); } } return nullptr; } static void inlineCalls(Block* block) { for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;) { Node* cur = *it++; switch (cur->kind()) { case prim::CallFunction: { if (auto graphFunction = tryToGraphFunction(cur)) { auto function_constant = cur->input(0)->node(); auto fun_type = function_constant->output()->type()->expect(); cur->removeInput(0); GRAPH_UPDATE( "Inlining function '", fun_type->function()->name(), "' to ", *cur); std::shared_ptr g = nullptr; // inline optimized graph for debugging/testing purposes. // we only insert fallback functions in JIT optimized graphs for // execution, not on the Graph that is used for serialization bool fallback = function_constant->hasAttribute(Symbol::attr("fallback")); if (fallback && graphFunction->get_executor().isOptimized()) { auto exec_plans = graphFunction->get_executor().getDebugState().execution_plans; if (!exec_plans.empty()) { g = exec_plans.begin()->second.graph; // optimized_graph() calls Inline, so we only need to explicitly // invoke inlining on the jit optimized graph with recursive // fallback function calls Inline(*g); } } if (g == nullptr) { g = graphFunction->optimized_graph(); } GRAPH_UPDATE("Function body: ", g); inlineCallTo(cur, graphFunction, g.get()); } } break; case prim::CallMethod: { if (auto graphFunction = tryToGraphFunction(cur)) { GRAPH_UPDATE("Inlining method '", cur->s(attr::name), "' to ", *cur); GRAPH_UPDATE("Function body: ", graphFunction->optimized_graph()); inlineCallTo(cur, graphFunction); } } break; default: { for (auto b : cur->blocks()) { inlineCalls(b); } } break; } } } void Inline(Graph& graph) { GRAPH_DUMP("Before Inlining: ", &graph); inlineCalls(graph.block()); GRAPH_DUMP("After Inlining: ", &graph); } } // namespace torch::jit