#include <torch/csrc/jit/tensorexpr/loop_nest.h>

#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ATen/core/functional.h>
#include <c10/util/Logging.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch {
namespace jit {
namespace tensorexpr {

LoopNest::LoopNest(const LoopNest& other)
    : root_stmt_(Stmt::clone(other.root_stmt_)),
      output_bufs_(other.output_bufs_) {
  verify(root_stmt_);
}

LoopNest::LoopNest(Stmt* stmt, std::unordered_set<const Buf*> output_bufs)
    : root_stmt_(stmt), output_bufs_(std::move(output_bufs)) {
  verify(root_stmt_);
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
LoopNest::LoopNest(
    const std::vector<Tensor*>& output_tensors,
    const std::vector<Tensor*>& tensors_to_compute) {
  initialize(output_tensors, tensors_to_compute);
  verify(root_stmt_);
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
LoopNest::LoopNest(const std::vector<Tensor*>& output_tensors) {
  initialize(output_tensors, output_tensors);
  verify(root_stmt_);
}

const std::unordered_set<const Buf*> LoopNest::getIntermediateBufs() const {
  std::unordered_set<const Buf*> result;
  auto input_bufs = getInputBufs();
  auto bufs = NodeFinder<Buf>::find(root_stmt_);
  for (auto* buf : bufs) {
    if (!output_bufs_.count(buf) && !input_bufs.count(buf)) {
      result.insert(buf);
    }
  }
  return result;
}

const std::unordered_set<const Buf*> LoopNest::getInputBufs() const {
  std::unordered_set<const Buf*> result;
  auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
  for (const auto& kv : buf_load_store_uses) {
    bool has_store = false;
    for (const auto& use : kv.second) {
      if (use.isStore) {
        has_store = true;
        break;
      }
    }
    if (!has_store) {
      result.insert(kv.first);
    }
  }
  return result;
}

class IndexFlattener : public IRMutator {
 public:
  Stmt* flatten(Stmt* s) {
    return s->accept_mutator(this);
  }

  const Expr* mutate(const Load* v) override {
    if (v->indices().size() == 1) {
      return v;
    }
    return new Load(
        v->dtype(), v->buf(), {flatten_index(v->buf()->dims(), v->indices())});
  }

  Stmt* mutate(const Store* v) override {
    const Expr* value = v->value();
    const Expr* new_value = value->accept_mutator(this);
    if (v->indices().size() == 1 && value == new_value) {
      return (Stmt*)v;
    }
    return new Store(
        v->buf(), {flatten_index(v->buf()->dims(), v->indices())}, new_value);
  }
};

class Vectorizer : public IRMutator {
 public:
  Stmt* vectorize(const For* v) {
    Stmt* body = v->body();
    const Var* var = v->var();
    const Expr* start = v->start();
    const Expr* stop = v->stop();

    const IntImm* start_imm = dynamic_cast<const IntImm*>(start);
    const IntImm* stop_imm = dynamic_cast<const IntImm*>(stop);
    if (!start_imm) {
      throw std::runtime_error(
          "Can't vectorize due to non-constant loop start!");
    }

    if (!stop_imm) {
      throw std::runtime_error(
          "Can't vectorize due to non-constant loop stop!");
    }

    var_ = var;
    start_ = start_imm;
    lanes_ = stop_imm->value();

    Stmt* new_body = body->accept_mutator(this);
    if (new_body == body) {
      throw std::runtime_error("Vectorization failed!");
    }

    return new_body;
  }

  const Expr* mutate(const Add* v) override {
    std::vector<const Expr*> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) + ExprHandle(inputs[1]);
    });
  }

  const Expr* mutate(const Sub* v) override {
    std::vector<const Expr*> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) - ExprHandle(inputs[1]);
    });
  }

  const Expr* mutate(const Mul* v) override {
    std::vector<const Expr*> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) * ExprHandle(inputs[1]);
    });
  }

  const Expr* mutate(const Div* v) override {
    std::vector<const Expr*> inputs = {v->lhs(), v->rhs()};
    return
try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) / ExprHandle(inputs[1]); }); } const Expr* mutate(const And* v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) & ExprHandle(inputs[1]); }); } const Expr* mutate(const Or* v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) | ExprHandle(inputs[1]); }); } const Expr* mutate(const Xor* v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) ^ ExprHandle(inputs[1]); }); } const Expr* mutate(const Lshift* v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) << ExprHandle(inputs[1]); }); } const Expr* mutate(const Rshift* v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) >> ExprHandle(inputs[1]); }); } const Expr* mutate(const Max* v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return Max::make( ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans()); }); } const Expr* mutate(const Min* v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return Min::make( ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans()); }); } const Expr* mutate(const CompareSelect* v) override { std::vector inputs = { v->lhs(), v->rhs(), v->ret_val1(), v->ret_val2()}; return try_vectorize(v, inputs, [&]() { return CompareSelect::make( ExprHandle(inputs[0]), ExprHandle(inputs[1]), ExprHandle(inputs[2]), ExprHandle(inputs[3]), v->compare_select_op(), v->bias()); }); } const Expr* mutate(const BitCast* v) override { std::vector inputs = {v->src_value()}; return try_vectorize(v, inputs, [&]() { return BitCast::make( Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0])); }); } const Expr* mutate(const Cast* v) override { std::vector inputs = {v->src_value()}; return try_vectorize(v, inputs, [&]() { return Cast::make( Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0])); }); } const Expr* mutate(const Var* v) override { if (v == var_) { return Ramp::make(ExprHandle(start_), 1, lanes_).node(); } return v; } const Expr* mutate(const Ramp* v) override { const Expr* base = v->base(); const Expr* stride = v->stride(); const Expr* base_new = base->accept_mutator(this); const Expr* stride_new = stride->accept_mutator(this); if (base_new == base && stride_new == stride) { return v; } throw std::runtime_error("Can't vectorize a Ramp!"); } const Expr* mutate(const Load* v) override { Dtype dtype(v->dtype().scalar_type(), lanes_); const Buf* buf = v->buf(); std::vector inputs = {v->flat_index()}; return try_vectorize(v, inputs, [&]() { return Load::make(dtype, BufHandle(buf), {ExprHandle(inputs[0])}); }); } const Expr* mutate(const ReduceOp* v) override { Dtype dtype(v->dtype().scalar_type(), lanes_); std::vector inputs = {v->body()}; auto* out = try_vectorize(v, inputs, [&]() { return ExprHandle( new ReduceOp(inputs[0], v->reduce_args(), v->reducer())); }); return out; } const Expr* mutate(const Broadcast* v) override { const Expr* val = v->value(); const Expr* new_val = val->accept_mutator(this); if (new_val == val) { return v; } throw std::runtime_error("Can't vectorize a Broadcast!"); } const Expr* mutate(const IfThenElse* v) override { const Expr* condition = v->condition(); 
const Expr* new_condition = condition->accept_mutator(this); if (new_condition != condition) { throw std::runtime_error("Can't vectorize an IfThenElse condition!"); } std::vector inputs = {v->true_value(), v->false_value()}; return try_vectorize(v, inputs, [&]() { return IfThenElse::make( ExprHandle(condition), ExprHandle(inputs[0]), ExprHandle(inputs[1])); }); } const Expr* mutate(const Intrinsics* v) override { std::vector inputs = v->params(); return try_vectorize(v, inputs, [&]() { return ExprHandle(new Intrinsics(v->op_type(), inputs)); }); } Stmt* mutate(const Store* v) override { const Buf* buf = v->buf(); std::vector inputs = {v->flat_index(), v->value()}; return try_vectorize(v, inputs, [&]() { return Store::make( BufHandle(buf), {ExprHandle(inputs[0])}, ExprHandle(inputs[1])); }); } Stmt* mutate(const For* v) override { const Var* var = v->var(); const Expr* start = v->start(); const Expr* stop = v->stop(); LoopOptions loop_options = v->loop_options(); const Expr* new_start = start->accept_mutator(this); const Expr* new_stop = stop->accept_mutator(this); if (new_start != start || new_stop != stop) { throw std::runtime_error( "Can't vectorize nested For with dependent loop bounds!"); } Stmt* body = v->body(); Stmt* new_body = body->accept_mutator(this); if (new_body == body) { return (For*)v; } return new For(var, new_start, new_stop, new_body, loop_options); } template const Expr* try_vectorize( const Expr* e, std::vector& inputs, T&& vec_ctor) { bool vectorize = vectorize_inputs(inputs); if (vectorize) { return vec_ctor().node(); } return e; } template Stmt* try_vectorize( const Stmt* s, std::vector& inputs, T&& vec_ctor) { bool vectorize = vectorize_inputs(inputs); if (vectorize) { return vec_ctor(); } return (Stmt*)s; } bool vectorize_inputs(std::vector& inputs) { bool any_vectorized = false; std::vector new_inputs; // Attempt to vectorize each input. for (const Expr*& in : inputs) { const Expr* new_in = in->accept_mutator(this); new_inputs.push_back(new_in); if (new_in != in) { any_vectorized = true; } } // If none of them vectorized, then don't vectorize this. if (!any_vectorized) { return false; } // Insert broadcasts for any inputs that weren't vectorized. for (size_t i = 0; i < inputs.size(); ++i) { if (inputs[i] == new_inputs[i]) { inputs[i] = Broadcast::make(ExprHandle(inputs[i]), lanes_).node(); } else { inputs[i] = new_inputs[i]; } } // And then vectorize this node. return true; } const Var* var_ = nullptr; int lanes_ = 0; const Expr* start_ = nullptr; }; void LoopNest::vectorize(For* f) { Block* b = dynamic_cast(f->get_parent()); if (!b) { return; } // Can't vectorize reduction axes. 
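  // Rough sketch of the rewrite this pass performs (illustrative only; the
  // buffer names A/B/C are hypothetical and the exact IR printout may
  // differ). A loop with constant bounds such as
  //
  //   for (int i = 0; i < 8; i++) {
  //     A[i] = B[i] + C[i];
  //   }
  //
  // is collapsed into a single wide statement over all 8 lanes,
  //
  //   A[Ramp(0, 1, 8)] = B[Ramp(0, 1, 8)] + C[Ramp(0, 1, 8)];
  //
  // with any scalar operands wrapped in Broadcast. Reduction axes are
  // rejected below because folding them into lanes would change the
  // reduction semantics; rfactor such loops first.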
auto reductions = NodeFinder::find(f); for (auto* r : reductions) { if (std::find(r->reduce_args().begin(), r->reduce_args().end(), f->var()) != r->reduce_args().end()) { throw std::logic_error("Cannot vectorize reduction axis - rfactor first"); } } Vectorizer v; Stmt* old_f = Stmt::clone(f); Stmt* new_f = nullptr; try { new_f = FlattenIndexes(f); new_f = v.vectorize(dynamic_cast(new_f)); } catch (std::runtime_error& e) { // Partial vectorization may have corrupted f new_f = old_f; } b->replace_stmt(f, IRSimplifier::simplify(new_f)); } void LoopNest::initialize( const std::vector& output_tensors, const std::vector& tensors_to_compute) { for (auto t : output_tensors) { output_bufs_.insert(t->buf()); } std::vector loops; for (Tensor* t : tensors_to_compute) { Stmt* loop = t->stmt(); if (loop->get_parent()) { std::cerr << "Error: creating a loopnest from already used Tensors\n"; loops = {}; break; } // Flatten initializers. if (Block* block = dynamic_cast(loop)) { for (auto* s : block->stmts()) { block->remove_stmt(s); loops.push_back(s); } } else { loops.push_back(loop); } } root_stmt_ = new Block(loops); } class FunctionInliner : public IRMutator { public: FunctionInliner(Store* producer, std::unordered_set outputs) : buf_(producer->buf()), producer_(producer), outputs_(std::move(outputs)) { for (auto* i : producer->indices()) { if (auto index_var = dynamic_cast(i)) { index_vars_.insert(index_var); producer_index_vars_.push_back(index_var); } else if (dynamic_cast(i) != nullptr) { // If the index can be a constant, then that dimension must have size 1 // (since we don't support in-place writes). Resolves issue 52581. TORCH_INTERNAL_ASSERT( dynamic_cast(i)->value() == 0, "Constant index impression should always be zero"); producer_index_vars_.push_back(nullptr); } else { throw std::logic_error("cannot inline Buf with compound indices"); } } } private: const Expr* mutate_loads(const Buf* buf, std::vector dims) { std::vector index_vars; TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size()); for (const auto i : c10::irange(buf->ndim())) { const Var* func_callee_arg = producer_index_vars_.at(i); const Expr* func_caller_param = dims.at(i); if (func_callee_arg == nullptr) { TORCH_INTERNAL_ASSERT( dynamic_cast(func_caller_param) != nullptr && dynamic_cast(func_caller_param)->value() == 0, "We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0"); continue; } if (func_callee_arg == nullptr) continue; auto iter = inline_mapping_.find(func_callee_arg); if (iter != inline_mapping_.end()) { throw std::runtime_error( "Duplicated variables: " + func_callee_arg->name_hint()); } // Add a mapping for each function parameter to it's source name. inline_mapping_[func_callee_arg] = func_caller_param; index_vars.push_back(func_callee_arg); } // Call the actual replacement. const Expr* body = producer_->value(); const Expr* result = body->accept_mutator(this); // Remove the mappings we created for this function parameters. 
for (auto* v : index_vars) { for (auto& pair : random_bindings_) { if (pair.second.erase(v)) { const Expr* inlined = inline_mapping_[v]; for (auto* nv : VarFinder::find(inlined)) { pair.second.insert(nv); } } } inline_mapping_.erase(v); } return result; } const Expr* mutate(const Load* v) override { const Buf* buf = v->buf(); if (buf != buf_) { return IRMutator::mutate(v); } if (v->indices().size() != buf->ndim()) { throw malformed_input( "Placeholder indexed access is inconsistent with its rank", v); } return mutate_loads(buf, v->indices()); } // Replace the target variable with the caller expressions. const Expr* mutate(const Var* v) override { auto iter = inline_mapping_.find(v); if (iter == inline_mapping_.end()) { return v; } else { const Expr* expr = iter->second; // Continue to transform the value from the lookup table. return expr->accept_mutator(this); } } // Handle random intrinsics which should be cached. const Expr* mutate(const Intrinsics* v) override { if (!in_producer_ || v->op_type() != kRand) { return IRMutator::mutate(v); } // Create a new Let Statment for the random variable, which we can refer to // multiple times and resolve the same value (ie. store it in a scalar // rather than the Tensor). const std::string& name = buf_->name_hint(); Var* new_var = new Var(name, v->dtype()); random_bindings_[new Let(new_var, v)] = index_vars_; return new_var; } // Remove the buffer write from the inlined function. Stmt* mutate(const Store* v) override { // If the buf_ is in the outputs set, keep its statement intact. Otherwise, // remove it. if (v == producer_ && !outputs_.count(buf_)) { in_producer_ = true; producer_ = dynamic_cast(IRMutator::mutate(v)); TORCH_INTERNAL_ASSERT(producer_ != nullptr); in_producer_ = false; return nullptr; } else { return IRMutator::mutate(v); } } // Any Random Instrinsics that were turned into vars must be inserted here. Stmt* mutate(const Block* v) override { std::vector stmts; for (Stmt* stmt : *v) { Stmt* stmt_new = stmt->accept_mutator(this); if (!stmt_new) { continue; } if (stmt == stmt_new) { stmt_new = Stmt::clone(stmt); } stmts.push_back(stmt_new); } return Block::make(stmts); } Stmt* mutate(const For* v) override { For* res = dynamic_cast(IRMutator::mutate(v)); if (!res) { return nullptr; } // Find any random bindings that should be defined in this loops body. std::vector bindings_this_loop; const Var* fv = v->var(); for (auto& pair : random_bindings_) { auto& index_var = pair.second; if (index_var.erase(fv)) { bindings_this_loop.push_back(pair.first); } } for (auto* l : bindings_this_loop) { res->body()->prepend_stmt(l); random_bindings_.erase(l); } return res; } private: const Buf* buf_; const Store* producer_; // Index Vars present in the producer. std::unordered_set index_vars_; std::vector producer_index_vars_; std::unordered_map inline_mapping_; // In the producer's scope - we need to bind any calls to rand(). bool in_producer_ = false; std::unordered_map> random_bindings_; std::unordered_set outputs_; }; bool LoopNest::computeInline(Stmt* s) { auto* s_store = dynamic_cast(s); if (s_store == nullptr) { throw std::logic_error("Could not find buffer producer to inline"); } return computeInline(s_store->buf()); } bool LoopNest::computeInline(const Buf* b) { // If buf is used or defined in an ExternalCall, we cannot inline it auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_); for (const auto& use : buf_load_store_uses.at(b)) { Stmt* s = use.s; if (dynamic_cast(s)) { return false; } } // Find producers. 
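  // Illustrative sketch of what a successful inline does once the single
  // producer store is found (buffer names are hypothetical; the rewrite is
  // carried out by FunctionInliner above):
  //
  //   for (int i = ...) { T[i] = A[i] * 2; }
  //   for (int j = ...) { B[j] = T[j] + 1; }
  //
  // becomes, after every load of T is replaced by the producer's right-hand
  // side (with the producer's index variables substituted) and the producer
  // store is deleted:
  //
  //   for (int j = ...) { B[j] = A[j] * 2 + 1; }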
Store* relevant_store{nullptr}; auto stores = NodeFinder::find(root_stmt_); for (auto* s : stores) { if (s->buf() == b) { auto reductions = NodeFinder::find(s); if (!reductions.empty()) { // Cannot inline a reduction computation return false; } if (relevant_store != nullptr) { // Cannot inline Buf with multiple Tensors return false; } relevant_store = s; } } TORCH_INTERNAL_ASSERT(relevant_store); FunctionInliner inliner(relevant_store, output_bufs_); root_stmt_ = root_stmt_->accept_mutator(&inliner); return true; } // inlining buffers with multiple uses can create duplicated work, which can // slow down cpu code generation but is enabled on gpu because it avoids // difficult synchronization logic across blocks. Inlining trivial reads does // not duplicate work void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) { std::unordered_set bufs_to_inline; auto intermediate_bufs = getIntermediateBufs(); if (allow_duplicated_work) { bufs_to_inline.insert(intermediate_bufs.begin(), intermediate_bufs.end()); } else { auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_); auto input_bufs = getInputBufs(); for (auto buf : intermediate_bufs) { TORCH_INTERNAL_ASSERT(buf_load_store_uses.count(buf)); std::vector& uses = buf_load_store_uses[buf]; auto stores = c10::filter( uses, [](const BufLoadOrStoreUse& use) { return use.isStore; }); // if the intermediate is the buffer formed from reading in the input // tensors, always inline, bc we are not duplicating any work // and avoiding an intermediary buffer if (stores.size() == 1) { if (auto store = dynamic_cast(stores[0].s)) { auto input_as_load = dynamic_cast(store->value()); if (input_as_load && input_bufs.count(input_as_load->buf())) { bufs_to_inline.insert(buf); continue; } } else { // If S is not a store, it must be an ExternalCall. 
TORCH_INTERNAL_ASSERT(dynamic_cast(stores[0].s)); } } // all bufs will have at least one store (if they have > 1 they cant be // inlined anyway) size_t reads = uses.size() - 1; // if only one read, we can inline it without duplicating work if (reads <= 1) { bufs_to_inline.insert(buf); } } } if (allow_duplicated_work) { bufs_to_inline.insert(output_bufs_.begin(), output_bufs_.end()); } for (auto b : bufs_to_inline) { computeInline(b); } } // TODO: Unify with DepTracker class LoadOrStoreUseFinder : public IRVisitor { public: std::unordered_map> findUses( Stmt* s) { uses_.clear(); s->accept(this); return uses_; } private: void visit(const Store* v) override { if (stores_[v->buf()].insert(last_stmt_).second) { uses_[v->buf()].push_back({(Stmt*)v, true}); } last_stmt_ = (Stmt*)v; IRVisitor::visit(v); } void visit(const ExternalCall* v) override { if (stores_[v->buf()].insert(last_stmt_).second) { uses_[v->buf()].push_back({(Stmt*)v, true}); } last_stmt_ = (Stmt*)v; for (const Buf* input_buf : v->buf_args()) { if (loads_[input_buf].insert(last_stmt_).second) { uses_[input_buf].push_back({last_stmt_, false}); } } IRVisitor::visit(v); } void visit(const Load* v) override { if (loads_[v->buf()].insert(last_stmt_).second) { uses_[v->buf()].push_back({last_stmt_, false}); } IRVisitor::visit(v); } Stmt* last_stmt_ = nullptr; std::unordered_map> uses_; // Sets of loads and stores in order to keep the results unique std::unordered_map> loads_; std::unordered_map> stores_; }; std::unordered_map> findLoadOrStoreUses(Stmt* s) { LoadOrStoreUseFinder uf; return uf.findUses(s); } class ContainedStmtsFinder : public IRVisitor { public: // Simply list all Stores and Block that are children of the given stmt const std::unordered_set& findContainedStmts(Stmt* s) { contained_.clear(); s->accept(this); return contained_; } private: void visit(const Store* v) override { contained_.insert((Stmt*)v); IRVisitor::visit(v); } void visit(const ExternalCall* v) override { contained_.insert((Stmt*)v); IRVisitor::visit(v); } void visit(const Block* v) override { contained_.insert((Stmt*)v); IRVisitor::visit(v); } std::unordered_set contained_; }; bool containsAll(const std::vector& uses, Block* b) { std::unordered_set not_found; for (auto use : uses) { not_found.insert(use.s); } ContainedStmtsFinder csf; const std::unordered_set& contained = csf.findContainedStmts(b); for (auto s : contained) { not_found.erase(s); } return not_found.empty(); } Block* findParentBlock(Stmt* s) { while (s) { if (auto b = dynamic_cast(s)) { return b; } s = s->get_parent(); } return nullptr; } Block* findLowestContainingBlock(const std::vector& uses) { // TODO: we're not using the most efficient algorithm here for simplicity. // Replace with something more performant in case it becomes a bottleneck. Block* b = findParentBlock(uses[0].s); while (b && !containsAll(uses, b)) { b = findParentBlock(b->get_parent()); } return b; } Stmt* LoopNest::insertAllocFree(Stmt* stmt) { auto intermediate_bufs = getIntermediateBufs(); if (intermediate_bufs.size() == 0ULL) { return stmt; } Block* b = dynamic_cast(stmt); if (!b) { b = new Block({stmt}); } std::unordered_map> uses = findLoadOrStoreUses(stmt); // Insert allocations and frees for temporary buffers in the innermost // possible scope. 
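  // For example (sketch; T is a hypothetical intermediate buffer): if all
  // uses of T sit inside the body of the loop over `i`, the Allocate/Free
  // pair is placed in that block rather than at the top level:
  //
  //   for (int i = ...) {
  //     Allocate(T);   // prepended to the lowest block containing all uses
  //     for (int j = ...) { T[j] = ...; }
  //     for (int j = ...) { B[i, j] = T[j]; }
  //     Free(T);       // appended to the same block
  //   }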
for (const Buf* buf : intermediate_bufs) { Stmt* alloc = new Allocate(buf); Stmt* free = new Free(buf); Block* alloc_block = findLowestContainingBlock(uses.at(buf)); alloc_block->prepend_stmt(alloc); alloc_block->append_stmt(free); } return b; } class StmtDeleter : public IRMutator { public: StmtDeleter(const std::unordered_set& targets) : targets_(targets) {} private: Stmt* mutate(const Block* v) override { std::vector stmts; for (auto* s : v->stmts()) { if (targets_.count(s) == 0) { Stmt* ns = s->accept_mutator(this); if (ns) { stmts.push_back(Stmt::clone(ns)); } } } return Block::make(stmts); } const std::unordered_set& targets_; }; void LoopNest::eliminateDeadStores() { using namespace analysis; MemDependencyChecker checker(getInputBufs(), getOutputBufs()); root_stmt_->accept(&checker); std::unordered_set deadStores; std::vector> outputAccesses; for (auto* o : getOutputBufs()) { outputAccesses.push_back(checker.output(o)); } for (auto& info : checker.getHistory()) { if (!info->isWrite()) { continue; } bool found = false; for (auto& output : outputAccesses) { if (checker.dependsIndirectly(output, info)) { found = true; break; } } if (!found) { deadStores.insert(info->stmt()); } } StmtDeleter deleter(deadStores); root_stmt_ = root_stmt_->accept_mutator(&deleter); } void LoopNest::prepareForCodegen() { // Expand reduction ops. ReductionExpander reduceExpander; root_stmt_ = reduceExpander.expand(root_stmt_); root_stmt_ = FlattenIndexes(root_stmt_); // Add allocs and frees for intermediate buffers at the global level. root_stmt_ = insertAllocFree(root_stmt_); } namespace { class IfThenElseReplacer : public IRMutator { public: IfThenElseReplacer(const IfThenElse* to_replace, const Expr* new_expr) : to_replace_(to_replace), new_expr_(new_expr) {} const Expr* mutate(const IfThenElse* i) override { if (i == to_replace_) { return new_expr_; } return i; } private: const IfThenElse* to_replace_; const Expr* new_expr_; }; // Check if the given condition is optimizable. // Specifically, this function looks for the following pattern: // "var < expr" // // If this pattern is found, then this function: // * sets `cond_var` to `var`, // * sets `compared_value` to `expr`, and // * returns true. bool isConditionOptimizable( const Expr* condition, const Var** cond_var, const Expr** compared_value) { auto cs = dynamic_cast(condition); if (cs && cs->compare_select_op() == kLT) { auto var = dynamic_cast(cs->lhs()); if (var) { *cond_var = var; *compared_value = cs->rhs(); return true; } } return false; } // Checks if the given if-then-else expression is a conditional that is // generated from `aten::cat`. // // The expected format of conditionals is: // IfThenElse(var < val1? 1 : 0, // IfThenElse (var < val2? 1 : 0, // IfThenElse (var < val3? 1 : 0, // sub-expr1, // sub-expr2), // sub-expr3), // sub-expr4) // // If such a conditional is found, this function also sets: // * cond_var to the condition variable found in this expression. // * comp_values to the list of compared values in the condition expressions. // * sub_exprs to the list of sub-expressions that are the result of this // if-then-else expression. bool isConditionalFromCat( const IfThenElse* ite, const Var** cond_var, std::vector* comp_values, std::vector* sub_exprs) { const Var* var = nullptr; const Expr* comp_value; if (isConditionOptimizable(ite->condition(), &var, &comp_value)) { if (*cond_var == nullptr) { *cond_var = var; } else if (*cond_var != var) { // Different condition variables found in nested if-then-else // expressions. 
Can not optimize such cases. return false; } auto true_ite = dynamic_cast(ite->true_value()); if (true_ite) { if (!isConditionalFromCat(true_ite, cond_var, comp_values, sub_exprs)) { return false; } } else { sub_exprs->push_back(ite->true_value()); } auto false_ite = dynamic_cast(ite->false_value()); if (false_ite) { return false; } comp_values->push_back(comp_value); sub_exprs->push_back(ite->false_value()); return true; } return false; } bool areConstantsAndSorted(const std::vector& comp_values) { std::vector comp_consts; comp_consts.reserve(comp_values.size()); for (auto c : comp_values) { if (!c->isConstant()) { return false; } comp_consts.push_back(immediateAs(c)); } return std::is_sorted(comp_consts.begin(), comp_consts.end()); } } // namespace bool LoopNest::optimizeConditionals() { // Consider every store in the root_stmt_ and try to optimize the // conditionals in that store. auto stores = NodeFinder::find(root_stmt_); std::unordered_set split_fors; for (auto store : stores) { const Var* cond_var = nullptr; // `comp_values` represent the list of compared values that will be // collected as we check for the expected pattern. Since that will // only include the RHS of the conditions in the if-then-else expressions // we need to start with `0` which is the initial bound, given that we // only handle normalized loops (check for this is done below). std::vector comp_values = {new IntImm(0)}; std::vector sub_exprs; auto ifthenelse_exprs = NodeFinder::find(store); if (ifthenelse_exprs.empty()) { continue; } // We only check if the first if-then-else expression in this store // corresponds to a conditional of the required format. If there are more // than one such conditional, optimizing them requires checking if the // conditions are exactly the same across them and handling all of them // together. Currently, this is not handled. if (!isConditionalFromCat( ifthenelse_exprs.front(), &cond_var, &comp_values, &sub_exprs)) { continue; } auto fors = getLoopStmtsFor(store); if (cond_var != fors.back()->var()) { // Currently, we only handle the case where the condition variable // is the same as the inner-most loop variable. // TODO: Handle all other cases here. // // In order to handle all other cases, the method `clone_and_replace` // called below to clone the body of the loop with a new store needs // to recursively handle cloning of the loops and other blocks it // contains. continue; } auto for_to_split = fors.back(); if (!LoopNest::isNormalized(for_to_split)) { // Do not optimize this conditional since the condition variable // refers to a loop that is not normalized. continue; } if (split_fors.count(for_to_split)) { // This loop has already been split while optimizing conditionals // earlier. // // Optimizing multiple conditionals that require splitting the same loop // is tricky. It requires checking if the conditions are exactly the same // across them and handling all of them together by splitting the loop // exactly once. // // Currently, this case is not supported. continue; } split_fors.insert(for_to_split); // `comp_values` needs to include the end bound, which is `for_to_split` // stop value. comp_values.push_back(for_to_split->stop()); // Check if all `comp_values` are constants and they are sorted. if (!areConstantsAndSorted(comp_values)) { continue; } // Remove all the if-then-else expressions from this store and create // one loop per sub-expression. 
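    // Continuing the pattern described above, a sketch of the intended
    // rewrite (buffer names are hypothetical):
    //
    //   for (int i = 0; i < 20; i++) {
    //     A[i] = IfThenElse(i < 5 ? 1 : 0, B[i], C[i - 5]);
    //   }
    //
    // is split at the compared values {0, 5, 20} into one loop per
    // sub-expression, each normalized afterwards:
    //
    //   for (int i = 0; i < 5; i++)  { A[i] = B[i]; }
    //   for (int i = 0; i < 15; i++) { A[i + 5] = C[i]; }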
std::vector split_loops; auto cond_to_replace = ifthenelse_exprs.front(); for (size_t i = 0; i < sub_exprs.size(); ++i) { IfThenElseReplacer ifthenelseReplacer(cond_to_replace, sub_exprs[i]); auto new_store = store->accept_mutator(&ifthenelseReplacer); auto new_for_body = for_to_split->body()->clone_and_replace(store, new_store); auto new_for = new For( for_to_split->var(), comp_values[i], comp_values[i + 1], new_for_body); LoopNest::normalize(new_for); split_loops.push_back(new_for); } auto par = dynamic_cast(for_to_split->get_parent()); par->replace_stmt(for_to_split, new Block(split_loops)); } root_stmt_ = IRSimplifier::simplify(root_stmt_); return true; } void LoopNest::vectorizeInnerLoops() { std::vector innerLoops; std::vector worklist; // Find outer-most For loops if (For* rootF = dynamic_cast(root_stmt_)) { worklist.push_back(rootF); } else if (Block* body = dynamic_cast(root_stmt_)) { std::vector blocks = {body}; while (blocks.size()) { Block* b = blocks.back(); blocks.pop_back(); for (Stmt* s : *b) { if (For* f = dynamic_cast(s)) { worklist.push_back(f); } else if (Block* b2 = dynamic_cast(s)) { blocks.push_back(b2); } } } } // Traverse the For loop nest find inner-most loops, which are // vectorization candidates. while (worklist.size()) { For* f = worklist.back(); worklist.pop_back(); bool containsSubLoops = false; if (Block* body = dynamic_cast(f->body())) { for (Stmt* s2 : *body) { if (For* f2 = dynamic_cast(s2)) { containsSubLoops = true; worklist.push_back(f2); } } } if (!containsSubLoops) { innerLoops.push_back(f); } } // vectorize inner loops. for (For* loop : innerLoops) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* split1; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* tail1; static const int kBodyVectorWidth = 8; splitWithTail(loop, kBodyVectorWidth, &split1, &tail1); vectorize(split1); if (tail1) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* split2; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* tail2; static const int kTailVectorWidth = 4; splitWithTail(tail1, kTailVectorWidth, &split2, &tail2); vectorize(split2); } } } void LoopNest::sliceHead(For* f, int factor, For** head, For** tail) { if (dynamic_cast(f->start()) && dynamic_cast(f->stop())) { int start_val = dynamic_cast(f->start())->value(); int stop_val = dynamic_cast(f->stop())->value(); int size_val = stop_val - start_val; if (factor >= size_val) { *head = f; *tail = nullptr; return; } } if (!f) { throw malformed_input("sliceHead attempted on null loop", f); } Block* p = dynamic_cast(f->get_parent()); if (!p) { throw malformed_input("sliceHead attempted on loop with no parent", p); } const Expr* head_end = new Min(new Add(f->start(), new IntImm(factor)), f->stop(), true); *head = new For(f->var(), f->start(), head_end, Stmt::clone(f->body())); *tail = new For( f->var(), head_end, f->stop(), Stmt::clone(f->body()), f->loop_options()); p->replace_stmt(f, *head); p->insert_stmt_after(*tail, *head); if (f->loop_options().is_gpu_block_index() || f->loop_options().is_gpu_thread_index()) { LoopNest::normalize(*tail); } } void LoopNest::sliceHead(For* f, int factor) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For *head, *tail; sliceHead(f, factor, &head, &tail); } void LoopNest::sliceTail(For* f, int factor, For** head, For** tail) { if (dynamic_cast(f->start()) && dynamic_cast(f->stop())) { int start_val = dynamic_cast(f->start())->value(); int stop_val = dynamic_cast(f->stop())->value(); int size_val = stop_val - start_val; if (factor >= size_val) { *head = 
nullptr; *tail = f; return; } } if (!f) { throw malformed_input("sliceTail attempted on null loop", f); } Block* p = dynamic_cast(f->get_parent()); if (!p) { throw malformed_input("sliceTail attempted on loop with no parent", p); } const Expr* tail_start = new Max(f->start(), new Sub(f->stop(), new IntImm(factor)), true); *head = new For( f->var(), f->start(), tail_start, Stmt::clone(f->body()), f->loop_options()); *tail = new For(f->var(), tail_start, f->stop(), Stmt::clone(f->body())); p->replace_stmt(f, *head); p->insert_stmt_after(*tail, *head); if (f->loop_options().is_gpu_block_index() || f->loop_options().is_gpu_thread_index()) { LoopNest::normalize(*head); } } void LoopNest::sliceTail(For* f, int factor) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For *head, *tail; sliceTail(f, factor, &head, &tail); } void LoopNest::splitWithTail(For* f, int factor) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For *inner, *tail; splitWithTail(f, factor, &inner, &tail); } void LoopNest::splitWithTail(For* f, int factor, For** inner, For** tail) { if (!f) { throw malformed_input("splitWithTail attempted on null loop", f); } Block* p = dynamic_cast(f->get_parent()); if (!p) { throw malformed_input("splitWithTail attempted on loop with no parent", p); } bool tail_is_needed = true; if (dynamic_cast(f->start()) && dynamic_cast(f->stop())) { int start_val = dynamic_cast(f->start())->value(); int stop_val = dynamic_cast(f->stop())->value(); int size_val = stop_val - start_val; int tail_size = size_val % factor; if (tail_size == 0) { tail_is_needed = false; } } const IntImm* factor_expr = new IntImm(factor); const Expr* size = new Sub(f->stop(), f->start()); const Expr* split_count = new Div(size, factor_expr); const Expr* tail_size = new Mod(size, factor_expr); const std::string& loop_var_name = f->var()->name_hint(); Dtype loop_var_dtype = f->var()->dtype(); const Var* i_inner = new Var(loop_var_name + "_inner", loop_var_dtype); const Var* i_outer = new Var(loop_var_name + "_outer", loop_var_dtype); // x -> x.outer * inner.size + x.inner const Expr* combined_index1 = new Add(new Mul(i_outer, factor_expr), i_inner); if (tail_is_needed) { const Var* i_tail = new Var(loop_var_name + "_tail", loop_var_dtype); // x -> x.tail + outer.size * inner.size const Expr* combined_index2 = new Add(i_tail, new Mul(split_count, factor_expr)); Stmt* body_tail = Substitute(Stmt::clone(f->body()), {{f->var(), combined_index2}}); *tail = new For(i_tail, new IntImm(0), tail_size, body_tail); p->insert_stmt_after(*tail, f); } else { *tail = nullptr; } Stmt* body_inner = Substitute(f->removeBody(), {{f->var(), combined_index1}}); *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner); // The input loop `f` will be the outer loop after split. 
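  // End-to-end effect of splitWithTail with factor 4 (illustrative sketch;
  // `A` is a hypothetical buffer):
  //
  //   for (int x = 0; x < 10; x++) { A[x] = ...; }
  //
  // becomes
  //
  //   for (int x_outer = 0; x_outer < 2; x_outer++)
  //     for (int x_inner = 0; x_inner < 4; x_inner++)
  //       A[x_outer * 4 + x_inner] = ...;
  //   for (int x_tail = 0; x_tail < 2; x_tail++)
  //     A[x_tail + 8] = ...;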
f->setVar(i_outer); f->setStart(new IntImm(0)); f->setStop(split_count); f->setBody(*inner); } void LoopNest::splitWithMask(For* f, int factor) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* inner; splitWithMask(f, factor, &inner); } void LoopNest::splitWithMask(For* f, int factor, For** inner) { Block* p = dynamic_cast(f->get_parent()); if (!p) { std::cerr << "Parent is not a Block!\n"; return; } bool tail_is_needed = true; const Expr* start = IRSimplifier::simplify(f->start()); const Expr* stop = IRSimplifier::simplify(f->stop()); if (start->isConstant() && stop->isConstant()) { int start_val = immediateAs(start); int stop_val = immediateAs(stop); int size_val = stop_val - start_val; int tail_size = size_val % factor; if (tail_size == 0) { tail_is_needed = false; } } const IntImm* factor_expr = new IntImm(factor); const Expr* size = new Sub(f->stop(), f->start()); // split_count = (size + factor - 1) / factor const Expr* split_count = new Div(new Sub(new Add(size, factor_expr), new IntImm(1)), factor_expr); const std::string& loop_var_name = f->var()->name_hint(); Dtype loop_var_dtype = f->var()->dtype(); const Var* i_inner = new Var(loop_var_name + "_inner", loop_var_dtype); const Var* i_outer = new Var(loop_var_name + "_outer", loop_var_dtype); // x -> x.outer * inner.size + x.inner const Expr* combined_index = new Add(new Mul(i_outer, factor_expr), i_inner); Stmt* body_inner = f->removeBody(); // TODO: is it ok that we're doing it eagerly? In the other implementation we // are only materializing predicates at the last, lowering, step. if (tail_is_needed) { const IntImm* start = dynamic_cast(f->start()); if (!start || start->value() != 0) { throw unimplemented_lowering(); } const Expr* predicate = CompareSelect::make(ExprHandle(f->var()), ExprHandle(f->stop()), kLT) .node(); body_inner = Cond::make(ExprHandle(predicate), body_inner, nullptr); } body_inner = Substitute(body_inner, {{f->var(), combined_index}}); *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner); // The input loop `f` will be the outer loop after split. f->setVar(i_outer); f->setStart(new IntImm(0)); f->setStop(split_count); f->setBody(*inner); } std::vector LoopNest::distributeLoop( For* loop, const std::unordered_set& pivots) { TORCH_INTERNAL_ASSERT(loop); auto root = loop->get_parent(); if (root == nullptr) { throw malformed_input("Loop without parent: ", loop); } auto root_block = dynamic_cast(root); if (root_block == nullptr) { throw malformed_input( "Loop's parent must be a Block, instead found ", root); } // Extract bodies for all the loops after distribution. std::vector new_loop_bodies; auto new_loop_body = new Block({}); while (!loop->body()->empty()) { auto s = loop->body()->front(); loop->body()->remove_stmt(s); new_loop_body->append_stmt(s); if (pivots.count(s)) { new_loop_bodies.push_back(new_loop_body); new_loop_body = new Block({}); } } if (!new_loop_body->empty()) { new_loop_bodies.push_back(new_loop_body); } // The first loop body has to be in the original loop. loop->body()->splice(loop->body()->begin(), new_loop_bodies.front()); std::vector new_loops = {loop}; // Create loops for all the remaining blocks. // Add all the new loops to the parent block. 
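  // Illustrative sketch of the final result (S1..S3 stand for arbitrary
  // statements): distributing
  //
  //   for (int i = ...) { S1; S2; S3; }
  //
  // over the pivot set {S1, S2} yields three sibling loops with identical
  // bounds, each carrying one piece of the original body:
  //
  //   for (int i = ...) { S1; }
  //   for (int i = ...) { S2; }
  //   for (int i = ...) { S3; }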
for (size_t i = 1; i < new_loop_bodies.size(); ++i) { auto new_loop = loop->cloneWithNewBody(new_loop_bodies[i]); root_block->insert_stmt_after(new_loop, new_loops.back()); new_loops.push_back(new_loop); } return new_loops; } std::vector LoopNest::distributeLoop(For* loop) { std::unordered_set stmtsInBlock( loop->body()->begin(), loop->body()->end()); return distributeLoop(loop, stmtsInBlock); } std::vector LoopNest::distributeLoopOverInnerLoops(For* loop) { auto loops = NodeFinder::find(loop); std::unordered_set loopsSet(loops.begin(), loops.end()); return distributeLoop(loop, loopsSet); } bool areEqual(const Expr* expr1, const Expr* expr2) { auto diff = IRSimplifier::simplify(new Sub(expr1, expr2)); return diff->isConstant() && (immediateAs(diff) == 0); }; bool areEqual( const std::vector& expr_list1, const std::vector& expr_list2) { if (expr_list1.size() != expr_list2.size()) { return false; } for (size_t i = 0; i < expr_list1.size(); ++i) { if (!areEqual(expr_list1[i], expr_list2[i])) { return false; } } return true; } bool LoopNest::hasLoopCarriedDependence(For* loop) { analysis::MemDependencyChecker analyzer; loop->accept(&analyzer); // High-level algorithm to check if two accesses to a buffer, A and B, one of // which is a Store, result in a loop-carried dependence: // 1. If the index expressions are equal in A and B, then that is a // loop-independent dependence. // 2. If the index expressions are not equal in A and B: // a) if the bounds on the accesses overlap, then this is a // loop-carried dependence. // b) if the bounds on the accesses do not overlap, then there is no // dependence. // // Implementation: // For every pair of statements, S1 and S2, in the loop: // * Get the loads and stores in S1 and S2. // * For every store in S1 and load in S2 to the same buffer, if the index // expressions are not equal and there is an overlap in accesses, return // true to indicate a loop-carried dependence. // * For every load in S1 and store in S2 to the same buffer, if the index // expressions are not equal and there is an overlap in accesses, return // true to indicate a loop-carried dependence. // * For every store in S1 and store in S2 to the same buffer, if the index // expressions are not equal and there is an overlap in accesses, return // true to indicate a loop-carried dependence. 
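  // Example (sketch; `A` is a hypothetical buffer): the loop below has a
  // loop-carried read-after-write dependence, because the store to A[i + 1]
  // in one iteration overlaps the load of A[i] in the next even though the
  // index expressions differ:
  //
  //   for (int i = 0; i < N - 1; i++) {
  //     A[i + 1] = A[i] + 1;
  //   }
  //
  // By contrast, `A[i] = A[i] + 1` has equal indices on both accesses, which
  // is a loop-independent dependence and is not reported by this check.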
for (auto it1 = loop->body()->begin(); it1 != loop->body()->end(); ++it1) { for (auto it2 = std::next(it1); it2 != loop->body()->end(); ++it2) { auto aStores = NodeFinder::find(*it1); auto aLoads = NodeFinder::find(*it1); auto bStores = NodeFinder::find(*it2); auto bLoads = NodeFinder::find(*it2); // ReadAfterWrite for (auto& aStore : aStores) { for (auto& bLoad : bLoads) { if (aStore->buf() == bLoad->buf()) { if (!areEqual(aStore->indices(), bLoad->indices())) { if (isOverlapping(analyzer, aStore, bLoad)) { return true; } } } } } // WriteAfterRead for (auto& bStore : bStores) { for (auto& aLoad : aLoads) { if (bStore->buf() == aLoad->buf()) { if (!areEqual(bStore->indices(), aLoad->indices())) { if (isOverlapping(analyzer, bStore, aLoad)) { return true; } } } } } // WriteAfterWrite for (auto& aStore : aStores) { for (auto& bStore : bStores) { if (aStore->buf() == bStore->buf()) { if (!areEqual(aStore->indices(), bStore->indices())) { if (isOverlapping(analyzer, aStore, bStore)) { return true; } } } } } } } return false; } bool LoopNest::fuseLoops(const std::vector& loops, For** fused) { if (loops.empty()) { return false; } if (loops.size() == 1) { *fused = loops.front(); return true; } // Check if all the loops have the same parent. auto root = loops.front()->get_parent(); for (auto l : loops) { auto par = l->get_parent(); if (par == nullptr) { return false; } if (par != root) { return false; } } auto root_block = dynamic_cast(root); if (root_block == nullptr) { return false; } // Currently, we only handle cases where there are no statements between // the given loops in their parents body. We can possibly relax this // constraint by allowing statements that do not affect the loops being // fused by performing some dependency analysis. TODO. auto it = root_block->begin(); for (; it != root_block->end(); ++it) { if (*it == loops.front()) { break; } } TORCH_INTERNAL_ASSERT(it != root_block->end()); for (auto l : loops) { if (*it != l) { return false; } ++it; } // Check if bounds are the same for all the loops. auto first_loop = loops.front(); auto first_loop_start = IRSimplifier::simplify(first_loop->start()); auto first_loop_stop = IRSimplifier::simplify(first_loop->stop()); for (size_t i = 1; i < loops.size(); ++i) { auto curr_loop = loops[i]; auto curr_loop_start = IRSimplifier::simplify(curr_loop->start()); auto curr_loop_stop = IRSimplifier::simplify(curr_loop->stop()); if (!areEqual(curr_loop_start, first_loop_start)) { return false; } if (!areEqual(curr_loop_stop, first_loop_stop)) { return false; } } // A lambda to fuse all the given loops. auto fuse_all_loops = [](const std::vector& loops) { auto first_loop = loops.front(); // Fuse the loops by taking all the statements from the second loops // onwards and moving them into the first loop's body. // This way the final fused loop will be the same as the first loop. for (size_t i = 1; i < loops.size(); ++i) { auto body = dynamic_cast(Substitute( Stmt::clone(loops[i]->body()), {{loops[i]->var(), first_loop->var()}})); first_loop->body()->splice(first_loop->body()->end(), body); } }; // We need to check if fusing the loops results in a loop-carried dependence. // This check can be done only after the loops are fused into one. But if the // check is violated, we need to return the given loops in the original form. // So, we create a clone of all the loops, fuse them and check for this. 
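  // Illustrative sketch of a successful fusion (hypothetical buffers):
  //
  //   for (int i = 0; i < N; i++) { A[i] = ...; }
  //   for (int j = 0; j < N; j++) { B[j] = A[j]; }
  //
  // is merged into a single loop over the first loop's variable,
  //
  //   for (int i = 0; i < N; i++) { A[i] = ...; B[i] = A[i]; }
  //
  // and the fusion is rejected only if the merged body would carry a
  // dependence across iterations (verified on the cloned copy below).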
std::vector loops_copy; loops_copy.reserve(loops.size()); for (const auto& l : loops) { loops_copy.push_back(dynamic_cast(Stmt::clone(l))); } fuse_all_loops(loops_copy); if (hasLoopCarriedDependence(loops_copy.front())) { return false; } // Now that all conditions are satisfied, we fuse the given loops. fuse_all_loops(loops); *fused = loops.front(); for (size_t i = 1; i < loops.size(); ++i) { root_block->remove_stmt(loops[i]); } return true; } For* findOuterFor(For* a, For* b) { Stmt* s = b; // guess b is the latter. while (s != nullptr) { if (s == a) { // yes, b is after a. return a; } s = s->get_parent(); } // check that the two are in the same loop nest. s = a; while (s != nullptr) { if (s == b) { // a is after b. return b; } s = s->get_parent(); } // a and b have no relationship. return nullptr; } void LoopNest::reorderAxis(For* a, For* b) { if (a == b) { // nothing to do. return; } // find inner and outer. For* outer = findOuterFor(a, b); if (outer == nullptr) { throw std::runtime_error("Reordered a loop not in LoopNest"); } For* inner = a == outer ? b : a; std::deque internal_axes; // Find relevant axes, store reversed. Stmt* s = inner; while (s != outer) { if (For* f = dynamic_cast(s)) { internal_axes.push_back(f); } // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) s = s->get_parent(); } internal_axes.push_back(outer); Block* root = dynamic_cast(outer->get_parent()); CHECK(root); // Do a shallow copy of the inner blocks. Block* body = new Block({}); body->splice(body->end(), inner->body()); For* before{outer}; For* after{nullptr}; For* last = internal_axes.front(); Stmt* newInner = body; s = inner; while (s != outer) { if (auto cond = dynamic_cast(s->get_parent())) { if (s == cond->true_stmt()) { newInner = cond->cloneWithNewBody(newInner); } else { // s is the false branch of Cond newInner = cond->cloneWithNewBodies(new Block({}), newInner); } } s = s->get_parent(); } // This is the major complexity in loop reordering: handling statements not in // the straight line of the reorder. To handle this we partition the tree into // the section before the critical path and after the critical path. // // An example of this pattern is: // for i in .. // Statement A // for j in .. // Statement B // Statement C // // When reordering loop i and j we need to ensure that Statement A and C are // still both executed with the loop extents of i, and that the three // statements are not reordered (as much as possible). for (auto* loop : internal_axes) { // If the inner loop had a component after the loop we must wrap it in a For // loop matching this level of the tree. if (after != nullptr) { after = loop->cloneWithNewBody(after); } bool pastMidpoint = false; bool hadBeforeStmts = false; for (auto I = loop->body()->begin(), E = loop->body()->end(); I != E;) { // Be careful not to invalidate the iterator. Stmt* s = *(I++); if (s == last) { // This is the midpoint. loop->body()->remove_stmt(s); if (!hadBeforeStmts) { // If there were no existing statements this loop does not need to be // preserved and we can roll it into the above loop. last = loop; } pastMidpoint = true; } else if (pastMidpoint) { // Statements after the reordered path must be moved to a new tree after // the reordered statement has occurred to preserve ordering. loop->body()->remove_stmt(s); if (after == nullptr) { after = loop->cloneWithNewBody(s); } else { after->body()->append_stmt(s); } } else { // We can leave any statements before the reordered loop alone, so long // as we preserve the loop structure. 
hadBeforeStmts = true; } } } // now we can actually reorder the chosen axes. std::swap(internal_axes.front(), internal_axes.back()); // Create the reordered internals: for (auto* loop : internal_axes) { newInner = loop->cloneWithNewBody(newInner); } // Append the new statements to the root of the tree. if (before->body()->nstmts() == 0) { // If the top level is now empty, eliminate it. root->replace_stmt(before, newInner); } else { root->insert_stmt_after(newInner, before); } if (after) { root->insert_stmt_after(after, newInner); } } bool isTrivialPermutation(const std::vector& permutation) { for (size_t i = 0; i < permutation.size(); ++i) { if (permutation[i] != i) { return false; } } return true; } bool isValidPermutation(std::vector permutation) { std::sort(permutation.begin(), permutation.end()); return isTrivialPermutation(permutation); } std::vector LoopNest::reorder( const std::vector& loops, const std::vector& permutation) { if (loops.size() != permutation.size()) { throw malformed_input("invalid permutation size"); } if (isTrivialPermutation(permutation)) { return loops; } if (!isValidPermutation(permutation)) { throw malformed_input("invalid permutation for reorder"); } if (loops.size() < 2) { return loops; } if (!areLoopsPerfectlyNested(loops)) { throw malformed_input("reorder is only allowed on perfectly nested loops"); } auto parent = dynamic_cast(loops.front()->get_parent()); if (parent == nullptr) { throw malformed_input("parent of the loops must be a Block"); } // Reorder the loops according to the permutation. std::vector result(loops.size()); for (size_t i = 0; i < loops.size(); ++i) { result[permutation[i]] = loops[i]; } // Remove the bodies from all the loops. auto innermost_body = loops.back()->removeBody(); // We use an empty block statement to replace the outermost loop // so that we know the position where the outermost reordered loop // is to be inserted. auto empty_block = new Block({}); parent->replace_stmt(loops.front(), empty_block); for (size_t i = 1; i < loops.size(); ++i) { auto block = dynamic_cast(loops[i]->get_parent()); TORCH_INTERNAL_ASSERT(block); block->remove_stmt(loops[i]); } // Set the new bodies after reorder for all the loops. 
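  // Illustrative sketch of the permutation semantics: for a perfectly nested
  // pair of loops
  //
  //   for i: for j: A[i, j] = ...
  //
  // reorder({i_loop, j_loop}, {1, 0}) produces
  //
  //   for j: for i: A[i, j] = ...
  //
  // i.e. permutation[k] is the new position of loops[k]; loop bounds are
  // untouched and only the nesting order changes.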
for (size_t i = 0; i < result.size() - 1; ++i) { result[i]->setBody(result[i + 1]); } result.back()->setBody(innermost_body); parent->replace_stmt(empty_block, result.front()); return result; } bool LoopNest::areLoopsPerfectlyNested(const std::vector& loops) { if (loops.size() < 2) { return true; } for (size_t i = 0; i < loops.size() - 1; ++i) { auto loop_body = loops[i]->body(); if (loop_body->nstmts() != 1 || loop_body->front() != loops[i + 1]) { return false; } } return true; } void LoopNest::unroll(For* f, Stmt** unrolled) { Block* p = dynamic_cast(f->get_parent()); if (!f) { throw malformed_input("unroll attempted on null loop"); } else if (!p) { throw malformed_input("unroll attempted on loop with no parent"); } auto start_expr = IRSimplifier::simplify(f->start()); auto stop_expr = IRSimplifier::simplify(f->stop()); if (!start_expr->isConstant()) { throw std::runtime_error("Can't unroll due to non-constant loop start!"); } if (!stop_expr->isConstant()) { throw std::runtime_error("Can't unroll due to non-constant loop stop!"); } std::vector unrolled_stmts; int start_val = immediateAs(start_expr); int stop_val = immediateAs(stop_expr); for (int current = start_val; current < stop_val; ++current) { for (const auto stmt : f->body()->stmts()) { auto stmt_copy = Stmt::clone(stmt); unrolled_stmts.push_back(Substitute( stmt_copy, {{f->var(), getImmediateByType(f->var()->dtype(), current)}})); } } *unrolled = new Block(unrolled_stmts); *unrolled = IRSimplifier::simplify(*unrolled); p->replace_stmt(f, *unrolled); } void LoopNest::unroll(For* f) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) Stmt* unrolled; unroll(f, &unrolled); } bool LoopNest::isNormalized(For* f) { if (f->start()->isConstant()) { return immediateAs(f->start()) == 0; } return false; } bool LoopNest::normalize(For* f) { if (!f) { throw malformed_input("normalize attempted on null loop"); } if (isNormalized(f)) { // No need to normalize anymore here. return false; } auto for_body_normalized = Substitute( f->body(), {{f->var(), (VarHandle(f->var()) + ExprHandle(f->start())).node()}}); f->setBody(for_body_normalized); f->setStop(new Sub(f->stop(), f->start())); f->setStart(new IntImm(0)); return true; } // This function expects that there are 'num' loops perfectly nested within // and including 'f'. std::vector LoopNest::getLoopStmtsInLoopNest(For* f, size_t num) { std::vector loops(num); For* curr_for = f; loops[0] = curr_for; for (size_t i = 1; i < num; ++i) { TORCH_INTERNAL_ASSERT(curr_for->body()->nstmts() == 1); curr_for = dynamic_cast(curr_for->body()->front()); TORCH_INTERNAL_ASSERT(curr_for); loops[i] = curr_for; } return loops; } bool LoopNest::flatten(const std::vector& loops, For** flattened) { if (loops.empty()) { throw malformed_input("flatten attempted on empty set of loops"); } Block* p = dynamic_cast(loops[0]->get_parent()); if (!p) { throw malformed_input("flatten attempted on loops with no parent"); } if (loops.size() == 1) { // This loop nest is already flattened. *flattened = loops[0]; return false; } // Check if all the loops correspond to a perfect loopnest: // * every loop except the inner-most should have only one stmt, the For. // Do not flatten, otherwise. // This check also ensures we do not flatten reduction loops. for (size_t i = 0; i < loops.size() - 1; ++i) { if ((loops[i]->body()->nstmts() != 1) || (loops[i]->body()->front() != loops[i + 1])) { return false; } } // Normalize the loops before flattening. 
// We need to normalize them from inner-most to outer because once the outer // loop is normalized, the given pointers to inner loops point to old code. // For the same reason, we can't store the normalized inner loops until after // the outer-most loop is normalized. // NOLINTNEXTLINE(cppcoreguidelines-init-variables) for (size_t i = 0; i < loops.size(); ++i) { size_t idx = loops.size() - i - 1; LoopNest::normalize(loops[idx]); } // 'normalized' points to the outer-most loop in the normalized loopnest. // Collect all the normalized loops. // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) auto normalized_loops = getLoopStmtsInLoopNest(loops.front(), loops.size()); auto flat_var = new Var( normalized_loops[0]->var()->name_hint() + "_flat", normalized_loops[0]->var()->dtype()); VarMapping var_mapping; Expr* stop = new IntImm(1); for (size_t i = 0; i < normalized_loops.size(); ++i) { size_t idx = normalized_loops.size() - i - 1; auto curr_loop = normalized_loops[idx]; Expr* div = new Div(flat_var, stop); Expr* sub_expr = idx == 0 ? div : new Mod(div, curr_loop->stop()); var_mapping.push_back(std::make_pair(curr_loop->var(), sub_expr)); stop = new Mul(curr_loop->stop(), stop); } auto flattened_body = Substitute(normalized_loops.back()->removeBody(), var_mapping); normalized_loops.front()->setVar(flat_var); normalized_loops.front()->setStart(new IntImm(0)); normalized_loops.front()->setStop(stop); normalized_loops.front()->setBody(flattened_body); *flattened = normalized_loops.front(); return true; } bool LoopNest::flatten(const std::vector& loops) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* flattened; return flatten(loops, &flattened); } void LoopNest::compressBuffer(Buf* buf, Stmt* stmt) { if (buf->initializer()) { throw malformed_input("Can't compress buffer whose initializer is set"); } // Loop iterations in NNC IR do not follow sequential semantics by default. // In other words, the iterations of the loops could be executed in any // random order without affecting correctness. This constraint in turn // implies that there can’t be any *inter-iteration* dependences // (or *loop-carried* dependences) in NNC loops. So, any NNC IR with such // dependences is considered invalid. // // Given the constraint above, for any pair of accesses to a buffer (where // at least one of the access is a write), the accesses must be // loop-independent on the innermost loop containing the accesses as well as // all the loops above it. So, any dimension that uses only those loop // variables to access the given buffer could be optimized away. // // Algorithm: // * Find all the accesses to the given buf. (A) // * Find the parent common to all accesses in A. (P) // * Collect all the loops above P. (L) // * Collect all the loop variables corresponding to L. (LV) // * For every access a in A: // * For the index I in every dimension of a: // * If the variables in I are all in LV, mark this dimension // for deletion. // * For every dimension that is marked for deletion in ALL accesses in A: // * Update the buffer to set the size of that dimension to 1. // * Update all accesses in A to set the index in that dimension to 0. auto writes = WritesToBuf::find(stmt, buf); auto reads = StmtsReadingBuf::find(stmt, buf); // All buffers must be read and written at least once. // Is this a valid assumption? TODO TORCH_INTERNAL_ASSERT(!writes.empty()); TORCH_INTERNAL_ASSERT(!reads.empty()); // Find the parent common to all the buffer accesses. 
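  // End-to-end illustration of the algorithm described above (sketch;
  // buffers T, A, B are hypothetical):
  //
  //   for (int i = ...) {
  //     for (int j = ...) { T[i, j] = A[i, j]; }
  //     for (int j = ...) { B[i, j] = T[i, j] * 2; }
  //   }
  //
  // Every access to T indexes its first dimension only with `i`, a loop
  // variable enclosing all of T's accesses, so that dimension is compressed:
  // T's shape becomes {1, N} and every access is rewritten to T[0, j].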
const Block* parent = dynamic_cast(writes.front()->get_parent()); TORCH_INTERNAL_ASSERT(parent); for (auto w : writes) { parent = Block::getSharedParent(parent, w); } for (auto r : reads) { parent = Block::getSharedParent(parent, r); } // Collect all the loops that are above the common parent. auto loops = LoopNest::getEnclosingLoopNest(parent); std::unordered_set loop_vars; for (auto l : loops) { loop_vars.insert(l->var()); } // TODO: Need to handle other Stmts / Exprs that read / write buffers. auto stores = NodeFinder::find(stmt); auto loads = NodeFinder::find(stmt); // Vector to indicate which dimensions could be compressed away. std::vector dims(buf->dims().size(), true); auto check_indices = [&](const std::vector& indices) { TORCH_INTERNAL_ASSERT(indices.size() == dims.size()); for (size_t i = 0; i < indices.size(); ++i) { auto index_vars = NodeFinder::find(indices[i]); for (auto iv : index_vars) { if (loop_vars.count(iv) == 0) { // A variable in this index is not in loop_vars. // This implies that this dimension cannot be optimized away. dims[i] = false; break; } } } }; for (auto s : stores) { if (s->buf() == buf) { check_indices(s->indices()); } } for (auto l : loads) { if (l->buf() == buf) { check_indices(l->indices()); } } bool any_dim_to_compress = false; for (auto d : dims) { any_dim_to_compress |= d; } if (!any_dim_to_compress) { return; } // Compress buffer by removing the marked dims. std::vector new_dims(buf->dims()); for (size_t i = 0; i < dims.size(); ++i) { if (dims[i]) { new_dims[i] = new IntImm(1); } } buf->set_dims(new_dims); // Modify all access to reflect the removed dims. auto get_new_indices = [&](const std::vector& indices) { TORCH_INTERNAL_ASSERT(indices.size() == dims.size()); std::vector new_indices(indices); for (size_t i = 0; i < dims.size(); ++i) { if (dims[i]) { new_indices[i] = new IntImm(0); } } return new_indices; }; for (auto s : stores) { if (s->buf() == buf) { s->set_indices(get_new_indices(s->indices())); } } for (auto l : loads) { if (l->buf() == buf) { l->set_indices(get_new_indices(l->indices())); } } } std::vector LoopNest::getLoopStmtsFor(Tensor* t) const { Stmt* cur_stmt = getLoopBodyFor(t); return getLoopStmtsFor(cur_stmt); } std::vector LoopNest::getLoopStmtsFor(const Buf* buf) const { Stmt* cur_stmt = getLoopBodyFor(buf); return getLoopStmtsFor(cur_stmt); } std::vector LoopNest::getLoopStmtsFor(Stmt* s) const { std::vector result; while (s) { if (auto* loop = dynamic_cast(s)) { result.push_back(loop); } s = s->get_parent(); } std::reverse(result.begin(), result.end()); return result; } void LoopNest::setGPUBlockIndex(For* f, int block_index) { f->set_gpu_block_index(block_index); } void LoopNest::setGPUThreadIndex(For* f, int thread_index) { f->set_gpu_thread_index(thread_index); } void LoopNest::setBufferMap( For* f, const std::unordered_map& map) { f->set_buffer_map(map); } Stmt* LoopNest::getLoopBodyFor(Tensor* t) const { return getLoopBodyFor(t->buf()); } Stmt* LoopNest::getLoopBodyFor(const Buf* buf) const { auto writes = WritesToBuf::find(root_stmt_, buf); // special case for reduction Tensors, ignore the initializer if it's the only // op: if (writes.size() == 2) { if (const Store* s = dynamic_cast(writes.back())) { if (const ReduceOp* r = dynamic_cast(s->value())) { return (Stmt*)s; // NOLINT } } } const Stmt* res = nullptr; for (const auto* s : writes) { if (!res) { res = s; continue; } res = Block::getSharedParent(res, s); } return (Stmt*)res; // NOLINT } For* LoopNest::getParentLoop(const Stmt* st) { if (st == nullptr) { return 
nullptr; } auto par = st->get_parent(); if (auto f = dynamic_cast(par)) { return f; } return getParentLoop(par); } std::vector LoopNest::getEnclosingLoopNest(const Stmt* st) { std::vector loops; auto f = getParentLoop(st); while (f) { loops.push_back(f); f = getParentLoop(f); } std::reverse(loops.begin(), loops.end()); return loops; } std::vector LoopNest::getAllWritesToBuf(const Buf* buf) const { return WritesToBuf::find(root_stmt_, buf); } std::vector LoopNest::getAllInnermostLoopsWritingToBuf( const Buf* buf) const { auto writes = getAllWritesToBuf(buf); std::vector innermost_loops; innermost_loops.reserve(writes.size()); for (auto w : writes) { innermost_loops.push_back(LoopNest::getParentLoop(w)); } return innermost_loops; } std::vector> LoopNest::getAllLoopNestsWritingToBuf( const Buf* buf) const { auto writes = getAllWritesToBuf(buf); std::vector> loopnests; loopnests.reserve(writes.size()); for (auto w : writes) { loopnests.emplace_back(LoopNest::getEnclosingLoopNest(w)); } return loopnests; } Stmt* LoopNest::simplify() { root_stmt_ = IRSimplifier::simplify(root_stmt_); return root_stmt_; } Stmt* FlattenIndexes(Stmt* s) { IndexFlattener idx_flattener; return idx_flattener.flatten(s); } // Auxiliary class for rewriting we're doing in `compute_at`. See // LoopNest::computeAt for more details. class LoopComputeAtRewriter : public IRMutator { public: LoopComputeAtRewriter( const Buf* buf, const Buf* new_buf, std::vector offsets) : buf_(buf), new_buf_(new_buf), offsets_(std::move(offsets)) {} private: const Buf* buf_; const Buf* new_buf_; std::vector offsets_; const Expr* mutate(const Load* v) override { if (v->buf() != buf_) { return v; } std::vector new_indices(v->indices().size()); for (const auto i : c10::irange(v->indices().size())) { new_indices[i] = IRSimplifier::simplify(new Sub(v->indices()[i], offsets_[i])); } return new Load(v->dtype(), new_buf_, new_indices); } }; static Store* getStoreStmtOfProducer(Stmt* s) { if (Store* st = dynamic_cast(s)) { return st; } if (Block* b = dynamic_cast(s)) { for (Stmt* ss : *b) { if (Store* st = dynamic_cast(ss)) { return st; } } } return nullptr; } static std::vector getOuterLoopIndexes(Stmt* s) { std::vector res; Stmt* cur = s; while (cur) { if (auto l = dynamic_cast(cur)) { res.push_back(l->var()); } cur = cur->get_parent(); } return res; } class CacheReplacer : public IRMutator { public: CacheReplacer( const Buf* buffer, const Buf* cache, std::vector& offsets) : buf_(buffer), cache_(cache), offsets_(offsets) {} private: const Expr* mutate(const Load* v) override { const Buf* buf = v->buf(); if (buf != buf_) { return IRMutator::mutate(v); } // Map indices to call-parameters. std::vector newIndices; TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size()); for (size_t i = 0; i < v->indices().size(); ++i) { const Expr* index = v->indices()[i]->accept_mutator(this); const Expr* offset = offsets_[i]; const Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); newIndices.push_back(sub); } return new Load(cache_, newIndices); } Stmt* mutate(const Store* v) override { const Buf* buf = v->buf(); if (buf != buf_) { return IRMutator::mutate(v); } const Expr* newValue = v->value()->accept_mutator(this); // Map indices to call-parameters. 
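    // For example (illustrative, assuming offsets_ = {start}): a store
    //   X[i] = v
    // to the original buffer is rewritten below into
    //   x_cache[i - start] = v
    // mirroring the Load rewrite above.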
    std::vector<const Expr*> newIndices;
    TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size());
    for (size_t i = 0; i < v->indices().size(); ++i) {
      const Expr* index = v->indices()[i]->accept_mutator(this);
      const Expr* offset = offsets_[i];
      const Expr* sub = IRSimplifier::simplify(new Sub(index, offset));
      newIndices.push_back(sub);
    }

    return new Store(cache_, newIndices, newValue);
  }

  const Buf* buf_;
  const Buf* cache_;
  std::vector<const Expr*>& offsets_;
};

LoopNest::AccessResult LoopNest::cacheAccesses(
    const Buf* producer,
    const std::string& name,
    Stmt* consumer) {
  const ReduceOp* reduceOp{nullptr};
  auto stores = NodeFinder<Store>::find(consumer);
  for (auto* store : stores) {
    if (auto ro = dynamic_cast<const ReduceOp*>(store->value())) {
      if (store->buf() != producer) {
        continue;
      }

      if (reduceOp) {
        throw std::runtime_error(
            "can only cache accesses used by at most a single reduceOp");
        return {nullptr, nullptr};
      }

      reduceOp = ro;
    }
  }

  // Check bounds but don't care about AccessKind.
  auto consumer_bounds_info = inferBounds(consumer, false);
  auto bounds_it = consumer_bounds_info.find(producer);
  if (bounds_it == consumer_bounds_info.end()) {
    throw std::runtime_error("consumer does not use the Tensor produced");
    return {nullptr, nullptr};
  }

  TORCH_INTERNAL_ASSERT(bounds_it->second.size() == 1);
  TensorAccessBoundsInfo& info = bounds_it->second[0];
  bool hasReads = info.kind == kLoad || info.kind == kMutate;
  bool hasWrites = info.kind == kStore || info.kind == kMutate;

  std::vector<std::string> var_names = {"i", "j", "k", "l", "m", "n", "o", "p"};
  std::vector<const Expr*> tmp_dims;
  std::vector<Var*> new_loop_vars;
  std::vector<const Expr*> new_loop_vars_expr;

  // Determine the size of the cache, and create a loop var for each dimension.
  for (size_t i = 0; i < info.start.size(); ++i) {
    const Expr* dim = IRSimplifier::simplify(
        new Add(new Sub(info.stop[i], info.start[i]), new IntImm(1)));

    tmp_dims.push_back(dim);

    new_loop_vars.push_back(new Var(var_names[i % var_names.size()], kInt));
    new_loop_vars_expr.push_back(new_loop_vars[i]);
  }

  // Create the var.
  Buf* tmp_buf = new Buf(new Var(name, kHandle), tmp_dims, producer->dtype());

  // Determine the offsets for calls into the cache based on the loop start of
  // each axis.
  std::vector<const Expr*> tmp_params;
  for (size_t i = 0; i < new_loop_vars.size(); ++i) {
    tmp_params.push_back(new Add(new_loop_vars[i], info.start[i]));
  }

  // Replace accesses to the producer in the consumer with the cache.
  CacheReplacer replacer(producer, tmp_buf, info.start);
  Stmt* new_consumer =
      IRSimplifier::simplify(consumer->accept_mutator(&replacer));

  // Replace the old consumer with the replaced consumer.
  Block* consumer_block = nullptr;
  // If the consumer is a block, we should mutate it in place.
  if ((consumer_block = dynamic_cast<Block*>(consumer))) {
    consumer_block->clear();
    consumer_block->append_stmt(new_consumer);
  } else {
    consumer_block = dynamic_cast<Block*>(consumer->get_parent());
    assert(consumer_block);
    consumer_block->replace_stmt(consumer, new_consumer);
  }

  // If there's a reduction and we are operating on the reduce axis, we need to
  // initialize the cache with 0s. Also, we can't just write the result
  // straight back to the original buffer, since after parallelism the writes
  // will race. Instead we need to create a new ReduceOp.
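  //
  // A rough sketch of that rewrite (illustrative only - the names and the
  // 1-D cached region of size N are assumptions, not taken from the code):
  //
  //   for i in 0..N:            // init the cache
  //     tmp[i] = 0
  //   <new_consumer>            // now accumulates into tmp
  //   for i in 0..N:            // reduce the cache back into the producer
  //     producer[i + start] = ReduceOp(producer[i + start] + tmp[i], ...)
  //
  // The non-reduction cases below instead copy values into the cache before
  // the consumer (hasReads) and write them back after it (hasWrites).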
  bool on_reduce_axis = false;
  if (reduceOp) {
    std::set<const Var*> reduce_args(
        reduceOp->reduce_args().begin(), reduceOp->reduce_args().end());
    std::set<const Var*> enclosing_vars;
    for (auto enclosing_for_stmt : NodeFinder<For>::find(consumer)) {
      enclosing_vars.insert(enclosing_for_stmt->var());
    }
    for (auto reduce_arg : reduce_args) {
      if (enclosing_vars.find(reduce_arg) == enclosing_vars.end()) {
        on_reduce_axis = true;
      }
    }
  }

  if (reduceOp && on_reduce_axis) {
    // reduceOp means we had both loads and stores.

    // Init cache to 0.
    Stmt* tmp_init = new Store(
        tmp_buf, new_loop_vars_expr, getImmediateByType(tmp_buf->dtype(), 0));

    for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
      tmp_init =
          new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_init);
    }

    consumer_block->insert_stmt_before(tmp_init, new_consumer);

    // Reduce back to the original buffer:
    Stmt* tmp_store = new Store(
        producer,
        tmp_params,
        reduceOp->reducer()(
            producer,
            ExprHandle(new Load(tmp_buf, new_loop_vars_expr)),
            tmp_params,
            {}));

    for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
      tmp_store =
          new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store);
    }

    consumer_block->insert_stmt_after(tmp_store, new_consumer);

    return std::make_pair(tmp_buf, new_consumer);
  }

  if (hasReads) {
    // Fill the cache with values from the consumer.
    Stmt* tmp_store =
        new Store(tmp_buf, new_loop_vars_expr, new Load(producer, tmp_params));

    for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
      tmp_store =
          new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store);
    }

    consumer_block->insert_stmt_before(tmp_store, new_consumer);
  }

  if (hasWrites) {
    // Sync the cache back to the producer buf.
    Stmt* tmp_store =
        new Store(producer, tmp_params, new Load(tmp_buf, new_loop_vars_expr));

    for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
      tmp_store =
          new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store);
    }

    consumer_block->insert_stmt_after(tmp_store, new_consumer);
  }

  return std::make_pair(tmp_buf, new_consumer);
}

/*
 * WHAT COMPUTE_AT DOES
 * ====================
 *
 * Suppose we have two loops:
 *
 * for i in 0..100:
 *   for j in 0..200:
 *     A[i,j] = sin(i*j)
 * for i in 0..100:
 *   for j in 0..199:
 *     B[i,j] = A[i,j] + A[i, j+1]
 *
 * If we compute these loops as is, we would have to allocate two buffers:
 * 100x200 for A and 100x199 for B. To decrease the memory usage one can use
 * the compute_inline primitive, which would result in the following:
 *
 * for i in 0..100:
 *   for j in 0..199:
 *     B[i,j] = sin(i*j) + sin(i*(j+1))
 *
 * We now need only one buffer - 100x199 for B. However, we're now doing some
 * redundant computations: we're calling `sin` twice as much as in the first
 * version.
 *
 * Ultimately, we need to choose at what point we prefer to compute the values
 * of A[i,j] - we can do it in the very beginning for the entire buffer A (the
 * first option) or compute it on the fly when we compute B (the second
 * option). There are also options in between those two: we can compute a part
 * of A which is required for a computation of a part of B, e.g. for a single
 * row of B. The code would then look like:
 *
 * for i in 0..100:
 *   for j in 0..200:
 *     A[j] = sin(i*j)
 *   for j in 0..199:
 *     B[i,j] = A[j] + A[j+1]
 *
 * In this case we're only using 1x200 for A, and we're avoiding redundant
 * computations.
 *
 * The purpose of `compute_at` is to achieve exactly this transformation.
 *
 * compute_at requires specifying What to compute and Where to compute it: in
 * our example we would call compute_at(What=`A[i,j] = sin(i*j)`, Where=`for i
 * in 0..100`).
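 *
 * Through the LoopNest API the call above might look roughly like this
 * (an illustrative sketch only; `A` and `B` stand for the tensors from the
 * example, and the statement/loop lookups are one possible way to obtain
 * the What/Where arguments):
 *
 *   LoopNest nest({B});
 *   Stmt* what = nest.getLoopBodyFor(A);       // A[i,j] = sin(i*j)
 *   For* where = nest.getLoopStmtsFor(B)[0];   // for i in 0..100
 *   nest.computeAt(what, where);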
 *
 * More info about compute_at can be found in Halide's tutorials:
 * https://halide-lang.org/tutorials/tutorial_lesson_08_scheduling_2.html
 *
 * HOW COMPUTE_AT WORKS
 * ====================
 *
 * The most important part of compute_at is bounds inference: we need to
 * figure out what part of the used tensors we need to compute when we move
 * the computation to a new scope. In the example above, we need bounds
 * inference to tell us that in order to compute A at each iteration of the
 * outer loop, we need to compute A within indices [i:i+1,0:200].
 *
 * This info allows us to conclude that we need a temp buffer of size 1x200.
 *
 * Once this is known we need to insert statements for allocating and freeing
 * the temporary buffer and copy the original computation to fill the temp
 * buffer with proper values. When we copy the computation we also must
 * rewrite the indices used in it: the old indices refer to the old loop and
 * are not valid in the new loop.
 *
 * To follow the logic more easily, let's examine an example. Suppose we start
 * from the following loop nest:
 *
 * for py in 0..100:
 *   for px in 0..100:
 *     producer[py,px] = py*px
 * for cy in 0..100:
 *   for cx in 0..100:
 *     consumer[cy,cx] = producer[cy,cx]
 *
 * And then we're running `compute_at(producer, cy)`.
 *
 * What we would like to get is the following loop nest:
 *
 * for py in 0..100:
 *   for px in 0..100:
 *     producer[py,px] = py*px
 * for cy in 0..100:
 *   Allocate(temp, {1, 100})
 *   for ty in 0..1:
 *     for tx in 0..100:
 *       temp[ty,tx] = (ty+cy)*(tx+0)
 *   for cx in 0..100:
 *     consumer[cy,cx] = temp[0,cx]
 *   Free(temp)
 *
 * NB: this loop nest can and should be simplified (e.g. the producer loop can
 * be removed since its result is no longer used), but this clean-up
 * optimization is performed separately (currently, not performed at all).
 *
 * If we examine the final loop nest, we can identify that the following steps
 * need to be performed:
 *   - Bounds inference needs to tell us that we need a 1x100 buffer for temp.
 *   - Allocate and Free statements for this buffer need to be inserted into
 *     the loop.
 *   - A new loop nest should be inserted into the loop cy for computing
 *     `temp`, and it should replicate the loop nest of the producer (the py,
 *     px loops). The indices in the loop body need to be offset by (cy, 0) -
 *     the offsets come from bounds inference too.
 *   - The computation of `consumer` needs to be rewritten so that it uses
 *     `temp` instead of `producer`. The indices in the corresponding accesses
 *     also need to be offset.
 */
void LoopNest::computeAt(Stmt* s, For* f) {
  Store* st = getStoreStmtOfProducer(s);
  if (!st) {
    return;
  }

  // Infer bounds info for all accesses that we make in the loop.
  auto loop_bounds_info = inferBounds(f->body());

  // bounds_it holds bounds info for the store we're trying to move to the
  // loop. If its result isn't accessed in the loop at all - do nothing and
  // exit early.
auto bounds_it = loop_bounds_info.find(st->buf()); if (bounds_it == loop_bounds_info.end()) { return; } // Compute dimensions of the temp buffer we would need to allocate std::vector dims = getBoundExtents(bounds_it->second); // TODO: Use name-hint of the producer instead of "temp" const Buf* temp_buf = new Buf("temp", dims, st->value()->dtype()); // Generate index variables for 'temp' std::vector temp_indices(dims.size()); for (const auto i : c10::irange(dims.size())) { // TODO: Use name-hint of the producer indices instead of 'idx' temp_indices[i] = new Var(std::string("idx") + c10::to_string(i), kInt); } // Prepare substitute rules for constructing the temp statement from the prod // statement // TODO: Instead of going up the loop nest we should go through the indices in // the original tensor expression. The loops in the nest might've been // modified (e.g. split or merged) so that the loop indices no longer // correspond to the indices of the original expression and even their number // might be different. In that case, the loop below would crash. std::vector prod_indices = getOuterLoopIndexes(s); std::vector> rewrite_indices_map; std::vector offsets; for (const TensorAccessBoundsInfo& p : bounds_it->second) { for (const auto i : c10::irange(p.start.size())) { if (offsets.size() <= i) { offsets.push_back(p.start[i]); } else { offsets[i] = IRSimplifier::simplify(new Min(offsets[i], p.start[i], true)); } } } for (const auto i : c10::irange(prod_indices.size())) { rewrite_indices_map.push_back( {prod_indices[i], new Add(temp_indices[i], offsets[i])}); } // Construct the temp statement Stmt* bd = new Store( temp_buf, temp_indices, Substitute(st->value(), rewrite_indices_map)); // Construct the loop nest for the temp computation for (const auto i : c10::irange(dims.size())) { // We're creating loops from innermost to outermost, so we need to access // dimensions in reversed order. 
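    // For example (illustrative): with dims = {D0, D1}, iteration i = 0 wraps
    // `bd` in the D1 loop first, and i = 1 then wraps that in the D0 loop, so
    // the D1 loop ends up innermost.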
size_t dim_idx = dims.size() - 1 - i; bd = new For( dynamic_cast(temp_indices[dim_idx]), new IntImm(0), dims[dim_idx], bd); } // Add constructed stmts to the consumer loop f->body()->prepend_stmt(bd); // Rewrite accesses to producer in consumer with accesses to temp LoopComputeAtRewriter lr(st->buf(), temp_buf, offsets); Stmt* new_f = f->accept_mutator(&lr); if (f != new_f) { Block* bb = dynamic_cast(f->get_parent()); bb->replace_stmt(f, new_f); } } class RfactorStoreRewriter : public IRMutator { public: RfactorStoreRewriter( const Buf* old_buf, const std::vector& old_indices, const Buf* new_buf, const Var* reduction_var) : old_buf_(old_buf), old_indices_(old_indices), new_buf_(new_buf), reduction_var_(reduction_var), new_indices_(old_indices) { new_indices_.push_back(reduction_var_); } const Expr* mutate(const Load* v) override { if (v->buf() != old_buf_) { return IRMutator::mutate(v); } TORCH_INTERNAL_ASSERT(old_indices_.size() == v->indices().size()); bool equal_indices = true; for (size_t i = 0; i < v->indices().size(); ++i) { if (!exprEquals(v->indices()[i], old_indices_[i])) { equal_indices = false; break; } } if (!equal_indices) { return IRMutator::mutate(v); } return new Load(new_buf_, new_indices_); } const Expr* mutate(const ReduceOp* v) override { const Expr* body_new = v->body()->accept_mutator(this); std::vector new_reduce_args; for (auto* r : v->reduce_args()) { if (r != reduction_var_) { new_reduce_args.push_back(r); } } return new ReduceOp(body_new, new_reduce_args, v->reducer()); } Stmt* mutate(const Store* v) override { if (v->buf() != old_buf_) { return IRMutator::mutate(v); } TORCH_INTERNAL_ASSERT(old_indices_.size() == v->indices().size()); bool equal_indices = true; for (size_t i = 0; i < v->indices().size(); ++i) { if (!exprEquals(v->indices()[i], old_indices_[i])) { equal_indices = false; break; } } if (!equal_indices) { return IRMutator::mutate(v); } const Expr* new_value = v->value()->accept_mutator(this); return new Store(new_buf_, new_indices_, new_value); } private: const Buf* old_buf_; const std::vector& old_indices_; const Buf* new_buf_; const Var* reduction_var_; std::vector new_indices_; }; bool LoopNest::rfactor(Stmt* st, For* target_for) { Buf* tmp_buf = nullptr; return rfactor(st, target_for, &tmp_buf); } bool LoopNest::rfactor(Stmt* st, For* outer_reduction_for, Buf** rfac_buf_ptr) { Store* reduction_store = dynamic_cast(st); const ReduceOp* reduce_op = dynamic_cast(reduction_store->value()); if (!reduce_op) { // Not a reduction store return false; } auto orig_buf = reduction_store->buf(); auto orig_buf_indices = reduction_store->indices(); const Var* reduction_var = outer_reduction_for->var(); std::set reduce_args = { reduce_op->reduce_args().begin(), reduce_op->reduce_args().end()}; if (reduce_args.size() < 2) { // Not enough reduction axis to do rfactor return false; } // Verify that outer_reduction_for is a perfect loop nest with all loops being // reductions Stmt* cur = outer_reduction_for; while (For* cur_for = dynamic_cast(cur)) { if (!reduce_args.count(cur_for->var())) { // output axis inside outer_reduction_for are not allowed return false; } reduce_args.erase(cur_for->var()); Block* b = cur_for->body(); if (b->nstmts() != 1) { return false; } cur = b->stmts().front(); } if (cur != st) { // The reduction store is not a single stmt in the innermost loop - bail in // that case return false; } if (!reduce_args.empty()) { // This is not the outermost reduction axis return false; } // assert: reduce_axis match loop vars from outer_reduction_for and 
inside // assert: no other stmts in outer_reduction_for or its child loops std::vector rfac_dims = orig_buf->dims(); const Expr* extra_dim = IRSimplifier::simplify( new Sub(outer_reduction_for->stop(), outer_reduction_for->start())); rfac_dims.push_back(extra_dim); const Expr* rfac_init = new Cast(reduce_op->dtype(), reduce_op->reducer().initializer()); *rfac_buf_ptr = new Buf( orig_buf->name_hint() + "_rfac", rfac_dims, reduce_op->dtype(), rfac_init); Buf* rfac_buf = *rfac_buf_ptr; // Rewrite the original reduction store to use the temporary rfac buffer: // 1) X[*indexes] --> T[*indexes + {reduction_var}] // 2) reduce_axis -= {reduction_var} RfactorStoreRewriter rfac_rewriter( orig_buf, orig_buf_indices, rfac_buf, reduction_var); dynamic_cast(st->get_parent()) ->replace_stmt(st, st->accept_mutator(&rfac_rewriter)); // Insert a store for the final reduction over the temp buffer into the // original buffer: // X[*indexes] = ReduceOp(X[*indexes] + T[*indexes + {reduction_var}], // reduce_axis={reduction_var}) Block* b = outer_reduction_for->body(); TORCH_INTERNAL_ASSERT(b->nstmts() == 1); Stmt* first_reduction_loop = b->stmts().front(); auto rfac_buf_indices = orig_buf_indices; rfac_buf_indices.emplace_back(reduction_var); const Expr* final_reduce_load = new Load(rfac_buf, rfac_buf_indices); outer_reduction_for->body()->insert_stmt_after( new Store( orig_buf, orig_buf_indices, reduce_op->reducer()( orig_buf, final_reduce_load, orig_buf_indices, {reduction_var})), first_reduction_loop); // Insert an initialization store for the temp buffer: // T[a,b,c] = init outer_reduction_for->body()->insert_stmt_before( new Store(rfac_buf, rfac_buf_indices, rfac_init), first_reduction_loop); return true; } } // namespace tensorexpr } // namespace jit } // namespace torch