#include <torch/csrc/jit/passes/loop_unrolling.h>

#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>

namespace torch::jit {

namespace {

static constexpr int64_t kUnrollFactor = 8;
static constexpr int64_t kMaxBodySize = 32;
static constexpr int64_t kMaxBodyRepeats = 64;

bool isTrueConstant(Value* val) {
  std::optional<bool> maybe_value = constant_as<bool>(val);
  return maybe_value && *maybe_value;
}

bool isForLoop(Node* node) {
  if (node->kind() != prim::Loop)
    return false;
  Value* start_cond = node->inputs().at(1);
  Value* continue_cond = node->blocks().at(0)->outputs().at(0);
  return isTrueConstant(start_cond) && isTrueConstant(continue_cond);
}

// Counts the size of this block, stopping and returning once it reaches
// `limit` instructions.
int64_t limitedBlockSize(Block* body, int64_t limit) {
  auto it = body->nodes().begin();
  auto end = body->nodes().end();
  for (int64_t i = 0; i < limit; ++it) {
    for (Block* subblock : it->blocks()) {
      i += limitedBlockSize(subblock, limit - i);
    }
    if (!it->notExecutedOp()) {
      ++i;
    }
    if (it == end) {
      return i;
    }
  }
  return limit;
}

bool isSmallBlock(Block* body) {
  return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize;
}

// XXX: This function can only be called with a loop that is guaranteed to
// execute EXACTLY ONCE.
void inlineBody(Node* loop) {
  auto graph = loop->owningGraph();
  auto body = loop->blocks().at(0);
  WithInsertPoint insert_point_guard{loop};

  std::unordered_map<Value*, Value*> value_map;
  auto get_value = [&](Value* v) {
    auto it = value_map.find(v);
    if (it != value_map.end())
      return it->second;
    return v;
  };

  // Loop node has extra (max_iters, initial_cond) inputs,
  // body has an extra (loop_counter) input.
  for (size_t i = 2; i < loop->inputs().size(); ++i) {
    value_map[body->inputs()[i - 1]] = loop->inputs()[i];
  }

  for (Node* orig : body->nodes()) {
    Node* clone = graph->insertNode(graph->createClone(orig, get_value));
    for (size_t i = 0; i < orig->outputs().size(); ++i) {
      value_map[orig->outputs()[i]] = clone->outputs()[i];
    }
  }
  for (size_t i = 0; i < loop->outputs().size(); ++i) {
    loop->outputs().at(i)->replaceAllUsesWith(
        get_value(body->outputs().at(i + 1)));
  }
  // XXX: it is extremely important to destroy the loop in here. DCE might not
  // be able to conclude that it's safe, because the loop might contain side
  // effects.
  loop->destroy();
}
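// Illustrative sketch (hand-written IR, not exact frontend output) of the
// prim::Loop layout the helpers in this file rely on: the node takes
// (max_trip_count, initial_condition, loop-carried inputs...), and its body
// block takes (iteration_counter, loop-carried inputs...) and yields
// (continue_condition, loop-carried outputs...):
//
//   %y = prim::Loop(%max_trip_count, %initial_cond, %x)
//     block0(%i, %x_in):
//       %x_out = aten::mul(%x_in, %x_in)
//       -> (%continue_cond, %x_out)
//
// For a loop known to run exactly once, inlineBody above maps %x_in to %x,
// clones the body nodes at the loop's insertion point, and rewires all uses
// of %y to the cloned %x_out.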
// Inserts a copy of `body`, passing `inputs` to the inputs of the block.
// It returns a list of the Values for the outputs of the block.
std::vector<Value*> insertBlockCopy(
    Graph& graph,
    Block* body,
    at::ArrayRef<Value*> inputs) {
  TORCH_INTERNAL_ASSERT(inputs.size() == body->inputs().size());
  std::unordered_map<Value*, Value*> value_map;
  auto get_value = [&](Value* v) {
    auto it = value_map.find(v);
    if (it != value_map.end())
      return it->second;
    return v;
  };
  auto inputs_it = inputs.begin();
  for (Value* input : body->inputs()) {
    value_map[input] = *inputs_it++;
  }
  for (Node* node : body->nodes()) {
    Node* new_node = graph.insertNode(graph.createClone(node, get_value));
    auto outputs_it = new_node->outputs().begin();
    for (Value* output : node->outputs()) {
      value_map[output] = *outputs_it++;
    }
  }
  return fmap(body->outputs(), get_value);
}

void repeatBody(Block* body, size_t times, Block* dest) {
  auto graph = body->owningGraph();
  WithInsertPoint insert_point_guard(dest);
  for (Value* input : body->inputs()) {
    dest->addInput()->copyMetadata(input);
  }

  std::vector<Value*> io = dest->inputs().vec();
  TORCH_INTERNAL_ASSERT(
      !body->inputs().at(0)->hasUses(), "loop counter should be unused");
  for ([[maybe_unused]] const auto i : c10::irange(times)) {
    io[0] = body->inputs().at(0);
    io = insertBlockCopy(*graph, body, io);
  }
  for (Value* output : io) {
    dest->registerOutput(output);
  }

  // It's likely that we have some dead nodes now - for example the "true"
  // constant that prevents the loop from breaking. We shouldn't wait too long
  // before removing them because they might artificially increase the loop
  // size and prevent outer loop unrolling.
  EliminateDeadCode(dest, false);
}

// Replaces the builtin loop counter with a "mutable" variable outside of the
// loop.
void replaceLoopCounter(Node* loop) {
  Graph* graph = loop->owningGraph();
  Block* body = loop->blocks().at(0);
  WithInsertPoint guard(loop);
  Value* init_counter = graph->insertConstant(0);

  loop->insertInput(2, init_counter);
  loop->insertOutput(0)->setType(IntType::get());

  Value* internal_counter =
      body->insertInput(1)->setType(init_counter->type());
  body->inputs()[0]->replaceAllUsesWith(internal_counter);

  WithInsertPoint insertPointGuard{body->return_node()};
  Value* result = graph->insert(aten::add, {internal_counter, 1});
  body->insertOutput(1, result);
}
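// Rough sketch of the transformation performed by unroll() below: a for-loop
// with trip count %n gets its body repeated kUnrollFactor times and its trip
// count replaced with %n divided by kUnrollFactor (rounded towards zero),
// while a clone of the original loop is inserted after it as an epilogue that
// runs the remaining %n - unrolled_count * kUnrollFactor iterations. Loops
// with a small constant trip count are fully unrolled and inlined instead,
// leaving no loop node behind.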
void unroll(Node* loop) {
  Graph* graph = loop->owningGraph();
  Block* body = loop->blocks().at(0);

  // We will be using a "mutable" counter outside of the loop instead of the
  // default one, because this will allow us to share it between the unrolled
  // loop and its epilogue. This is necessary only if the loop counter is
  // actually used in the body.
  if (!body->inputs()[0]->uses().empty())
    replaceLoopCounter(loop);

  // Some optimization for constant-length loops. If we know they won't run
  // too many times, then we can unroll them entirely.
  Value* trip_count = loop->inputs().at(0);
  std::optional<int64_t> const_len = constant_as<int64_t>(trip_count);
  if (const_len && *const_len < kMaxBodyRepeats) {
    Block* dest = loop->addBlock();
    repeatBody(body, *const_len, dest);
    loop->eraseBlock(0);
    inlineBody(loop);
    return;
  }

  WithInsertPoint insert_point_guard{loop};

  // Clone the loop before we unroll it. The clone will become the epilogue.
  Node* loop_epilogue =
      graph->createClone(loop, [](Value* v) { return v; })->insertAfter(loop);
  for (size_t i = 0; i < loop->outputs().size(); ++i) {
    loop->outputs()[i]->replaceAllUsesWith(loop_epilogue->outputs()[i]);
    loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
  }

  Block* dest = loop->addBlock();
  repeatBody(body, kUnrollFactor, dest);
  loop->eraseBlock(0);

  // Change the iteration counts of both loops.
  Value* iter_count = loop->inputs().at(0);
  Value* unrolled_iter_count = graph->insert(
      aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
  loop->replaceInput(0, unrolled_iter_count);
  loop_epilogue->replaceInput(
      0,
      graph->insert(
          aten::sub,
          {iter_count,
           graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
}

bool UnrollLoops(Block* block, bool constant_only) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    // XXX: unroll might destroy the current node, so we need to pre-increment
    // the iterator.
    Node* node = *it;
    ++it;
    for (Block* subblock : node->blocks()) {
      changed |= UnrollLoops(subblock, constant_only);
    }
    if (!isForLoop(node)) {
      continue;
    }
    if (constant_only) {
      if (node->inputs().at(0)->node()->kind() != prim::Constant) {
        continue;
      }
    } else if (!isSmallBlock(node->blocks().at(0))) {
      continue;
    }
    unroll(node);
    changed = true;
  }
  return changed;
}

} // anonymous namespace

static void addCondAsOutput(Node* loop) {
  LoopView loop_view(loop);
  loop->addInput(loop_view.inputCond());
  auto block_cond_input = loop_view.bodyBlock()->addInput();
  block_cond_input->copyMetadata(loop_view.inputCond());
  auto cond_output_index =
      loop_view.bodyBlock()->registerOutput(loop_view.nextCond());
  loop_view.bodyBlock()->outputs()[cond_output_index]->copyMetadata(
      loop_view.nextCond());
  auto cond_output = loop->addOutput();
  cond_output->copyMetadata(loop_view.nextCond());
}

bool LoopsPeeler::run(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before LoopsPeeler", graph);
  collectLoops(graph->block());
  peelLoops();
  GRAPH_DUMP("After LoopsPeeler", graph);
  return true;
}

void LoopsPeeler::collectLoop(Node* n) {
  if (callback_(n)) {
    if (in_loop_) {
      GRAPH_DEBUG("Loop ", getHeader(in_loop_), " will be peeled");
      loops_to_peel_.push_back(in_loop_);
      in_loop_ = nullptr;
    }
  }
}

void LoopsPeeler::collectLoops(Block* block) {
  // We do a pre-order traversal to reduce the number of peeled loops.
  for (auto n : block->nodes()) {
    collectLoop(n);
  }
  collectLoop(block->return_node());

  // Process child blocks.
  for (auto n : block->nodes()) {
    auto old_in_loop_ = in_loop_;
    if (n->kind() == prim::Loop) {
      in_loop_ = n;
    }
    for (auto b : n->blocks()) {
      collectLoops(b);
    }
    in_loop_ = old_in_loop_;
  }
}

void LoopsPeeler::peelLoops() {
  for (auto loop : loops_to_peel_) {
    PeelLoop(loop, num_iterations_);
  }
}

bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph) {
  auto peel_predicate = [](Node* n) {
    for (auto i : n->inputs()) {
      if (i->type()->isSubtypeOf(*TensorType::get())) {
        return true;
      }
    }
    return false;
  };

  LoopsPeeler lp(peel_predicate);
  return lp.run(graph);
}
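// Rough sketch of what PeelLoop does: given a loop with max trip count %n and
// a request to peel `times` iterations, it inserts a clone of the loop that
// runs min(%n, times) iterations before the original, rewires the original
// loop to consume the clone's loop-carried outputs and termination condition,
// shrinks its trip count by the peeled amount, and offsets its induction
// variable so the overall computation is unchanged.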
Node* PeelLoop(Node* n, size_t times) {
  GRAPH_DEBUG("Peeling the loop ", getHeader(n), " ", times, " times");

  auto graph = n->owningGraph();
  auto orig_loop = LoopView(n);

  WithInsertPoint wip(n);
  auto times_const = graph->insertConstant(static_cast<int64_t>(times));
  // N.B. even though a caller may request to peel `times` iterations,
  // `maxTripCount` of the original loop might be less than that,
  // so we should take the minimum of the two.
  auto min_trip_count =
      graph->insert(prim::min, {orig_loop.maxTripCount(), times_const});

  // Make the peeled clone.
  auto peeled_copy = graph->createClone(n, [](Value* v) { return v; });
  addCondAsOutput(peeled_copy);

  LoopView new_lv(peeled_copy);
  graph->insertNode(peeled_copy);

  // Only run until the peeled count.
  new_lv.replaceMaxTripCount(min_trip_count);

  // Decrease `maxTripCount` of the original loop by the number of iterations
  // the peeled loop runs.
  auto new_max_trip_count =
      graph->insert(aten::sub, {orig_loop.maxTripCount(), min_trip_count});
  orig_loop.replaceMaxTripCount(new_max_trip_count);

  // Update the termination condition.
  auto cond_index = peeled_copy->outputs().size() - 1;
  orig_loop.replaceInputCondition(peeled_copy->output(cond_index));

  static const size_t LOOP_DEPS_WITH_COND_OFFSET = 2;
  for (size_t i = 0; i < peeled_copy->outputs().size() -
           1 /* leave off the termination condition */;
       i++) {
    n->replaceInput(LOOP_DEPS_WITH_COND_OFFSET + i, peeled_copy->output(i));
  }

  // The induction variable also needs to be adjusted by the number of
  // iterations the peeled loop runs.
  {
    WithInsertPoint peeled_wip(*orig_loop.bodyBlock()->nodes().begin());
    // We can't build `adjusted_counter = old_counter + min_trip_count`
    // directly, because running
    // `old_counter->replaceAllUsesWith(adjusted_counter)` afterwards would
    // also rewrite the use inside the addition, yielding
    // `adjusted_counter = adjusted_counter + min_trip_count`. So we build the
    // addition with a placeholder first operand and patch it up afterwards.
    auto adjusted_iter_counter =
        graph->insert(aten::add, {min_trip_count, min_trip_count});
    orig_loop.currentTripCount()->replaceAllUsesWith(adjusted_iter_counter);
    adjusted_iter_counter->node()->replaceInput(
        0, orig_loop.currentTripCount());
  }

  return peeled_copy;
}

bool UnrollLoops(std::shared_ptr<Graph>& graph) {
  bool changed = UnrollLoops(graph->block(), false);
  if (changed) {
    EliminateDeadCode(graph);
  }
  return changed;
}

bool UnrollConstantLoops(std::shared_ptr<Graph>& graph) {
  bool changed = UnrollLoops(graph->block(), true);
  if (changed) {
    EliminateDeadCode(graph);
  }
  return changed;
}

} // namespace torch::jit