#include #include #include namespace torch::jit { // Canonicalize a graph, renumbering it so that all structurally equivalent // graphs have same numbers. // keep_unique_names: If false, canonicalizes unique names by removing them // and replacing them with normal value names. // Otherwise, ignores values with unique names. std::shared_ptr Canonicalize( const std::shared_ptr& graph, bool keep_unique_names) { auto r = std::make_shared(graph->current_scope()); std::unordered_map rn_env; auto rn_fn = [&](Value* v) { return rn_env.at(v); }; for (auto* input : graph->inputs()) { auto* r_input = r->addInput(); r_input->copyMetadata(input); if (!keep_unique_names) r_input->setDebugName(""); rn_env[input] = r_input; } for (auto* node : graph->nodes()) { auto* r_node = r->createClone(node, rn_fn); if (!keep_unique_names) { for (auto* output : r_node->outputs()) { output->setDebugName(""); } } r->appendNode(r_node); auto outputs = node->outputs(); auto r_outputs = r_node->outputs(); for (const auto i : c10::irange(outputs.size())) { rn_env[outputs.at(i)] = r_outputs.at(i); } if (node->hasAttribute(attr::Subgraph)) { r_node->g_( attr::Subgraph, Canonicalize(node->g(attr::Subgraph), keep_unique_names)); } } for (auto* output : graph->outputs()) { r->registerOutput(rn_fn(output)); } return r; } // Which index in b's owning Node is b static size_t blockIndex(const Block* b) { auto n = b->owningNode(); AT_ASSERT(n); for (size_t i = 0; i < n->blocks().size(); ++i) { if (n->blocks()[i] == b) { return i; } } AT_ASSERT(false); } /* * This establishes a canonical ordering of nodes. * If n1 and n2 are in the same block, whichever node appears first * is before the other. * If n1 and n2 are contained in different blocks of an if node, * then whichever block is in the true block is ordered before the other. * If n1 contains n2, then n1 is before n2. This has the nice property that * whichever node appears first in a dump of the graph is before the other. * NB: this is not a topological index. Topologically, two nodes in * different blocks of an if node are not topologically < or > each other. */ static bool isBefore(Node* n1, Node* n2) { // Invalid to call with the same node as both args AT_ASSERT(n1 != n2); // Set n1 and n2 to be the number of blocks from the Graph block size_t d_1 = n1->blocksFromGraphBlock(); size_t d_2 = n2->blocksFromGraphBlock(); for (; d_1 > d_2; --d_1) { n1 = n1->owningBlock()->owningNode(); // n2 contains n1 if (n1 == n2) { return false; } } for (; d_2 > d_1; --d_2) { n2 = n2->owningBlock()->owningNode(); // n1 contains n2 if (n2 == n1) { return true; } } // Now they are the same number of blocks from the graph block, // recurse upwards, checking if they are on the same block while (true) { if (n1->owningBlock() == n2->owningBlock()) { return n1->isBefore(n2); } auto new_n1 = n1->owningBlock()->owningNode(); auto new_n2 = n2->owningBlock()->owningNode(); AT_ASSERT(new_n1 != nullptr); AT_ASSERT(new_n2 != nullptr); if (new_n1 == new_n2) { // take whichever node is in the earlier block auto index_1 = blockIndex(n1->owningBlock()); auto index_2 = blockIndex(n2->owningBlock()); return index_1 < index_2; } n1 = new_n1; n2 = new_n2; } } static bool isBefore(const Use& a, const Use& b) { // If two uses are the same node, we order on offset if (a.user == b.user) { return a.offset < b.offset; } return isBefore(a.user, b.user); } static bool isAfter(const Use& a, const Use& b) { if (a.user == b.user && a.offset == b.offset) { return false; } return !isBefore(a, b); } bool isBeforeOrAfter(const Use& a, const Use& b, bool checking_before) { return checking_before ? isBefore(a, b) : isAfter(a, b); } std::optional firstOrLastUse(Value* v, bool find_first) { if (v->uses().empty()) { return std::nullopt; } Use extreme_use = v->uses()[0]; for (size_t i = 1; i < v->uses().size(); ++i) { auto n_use = v->uses()[i]; if (!isBeforeOrAfter(extreme_use, n_use, find_first)) { extreme_use = n_use; } } return extreme_use; } static std::vector> gatherFirstUses( at::ArrayRef values) { return fmap(values, [&](Value* v) -> std::optional { return firstOrLastUse(v, true); }); } static std::vector sort_indexes(at::ArrayRef values) { // initialize original index locations std::vector idx(values.size()); std::iota(idx.begin(), idx.end(), 0); std::vector> first_uses = gatherFirstUses(values); // Sort values based on canonical ordering of their first usage std::sort(idx.begin(), idx.end(), [&first_uses](size_t i1, size_t i2) { // if neither has any uses, use original ordering. Since the // only values that jitter are ones added by the compiler and are guaranteed // to have uses, original ordering is fine. if (first_uses[i1] == std::nullopt && first_uses[i2] == std::nullopt) { return i1 < i2; } if (first_uses[i1] == std::nullopt) { return false; } else if (first_uses[i2] == std::nullopt) { return true; } auto fst_v1 = *first_uses[i1]; auto fst_v2 = *first_uses[i2]; return isBefore(fst_v1, fst_v2); }); return idx; } static void CanonicalizeLoopOutputs(Node* n) { auto new_indices = sort_indexes(n->outputs()); LoopView(n).permuteLoopCarried(new_indices); } static void CanonicalizeIfOutputs(Node* n) { auto new_indices = sort_indexes(n->outputs()); IfView(n).permuteOutputs(new_indices); } static void CanonicalizeOutputs(Block* block) { // We iterate in reverse since ordering of a node's outputs is dependent on // the value use following it in the graph for (Node* n : block->nodes().reverse()) { switch (n->kind()) { case prim::Loop: { CanonicalizeLoopOutputs(n); } break; case prim::If: { CanonicalizeIfOutputs(n); } break; } // Since an a control flow node's outputs are after // the values outputted within its blocks, first canonicalize // the nodes outputs and then recurse on its blocks for (Block* b : n->blocks()) { CanonicalizeOutputs(b); } } } // Canonicalize a graph's control flow node outputs. We do this to solve jitter // issues with outputs added to control flow nodes after the first pass of // compilation in ir_emitter.cpp void CanonicalizeOutputs(std::shared_ptr& graph) { CanonicalizeOutputs(graph->block()); } } // namespace torch::jit