#include <torch/csrc/jit/tensorexpr/loopnest.h>

#include <algorithm>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ATen/core/functional.h>
#include <c10/util/Logging.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_cloner.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch::jit::tensorexpr {

LoopNest::LoopNest(const LoopNest& other)
    : root_stmt_(Stmt::clone(other.root_stmt_)),
      output_bufs_(other.output_bufs_) {
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

LoopNest::LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs)
    : root_stmt_(std::move(stmt)), output_bufs_(std::move(output_bufs)) {
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

LoopNest::LoopNest(
    const std::vector<Tensor>& output_tensors,
    const std::vector<Tensor>& tensors_to_compute) {
  initialize(output_tensors, tensors_to_compute);
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

LoopNest::LoopNest(const std::vector<Tensor>& output_tensors) {
  initialize(output_tensors, output_tensors);
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

std::vector<BufPtr> LoopNest::getIntermediateBufs() const {
  std::vector<BufPtr> result;
  std::unordered_set<BufPtr> result_set;
  auto input_bufs = getInputBufs();
  auto bufs = NodeFinder<Buf>::find(root_stmt_);
  for (const auto& buf : bufs) {
    if (!output_bufs_.count(buf) && !input_bufs.count(buf) &&
        !result_set.count(buf)) {
      result.push_back(buf);
      result_set.insert(buf);
    }
  }
  return result;
}

const std::unordered_set<BufPtr> LoopNest::getInputBufs() const {
  std::unordered_set<BufPtr> result;
  auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
  for (auto& kv : buf_load_store_uses) {
    bool has_store = false;
    for (auto& use : kv.second) {
      if (use.isStore) {
        has_store = true;
        break;
      }
    }
    if (!has_store) {
      result.insert(kv.first);
    }
  }
  return result;
}

class IndexFlattener : public IRMutator {
 public:
  StmtPtr flatten(const StmtPtr& s) {
    return s->accept_mutator(this);
  }

  ExprPtr mutate(const LoadPtr& v) override {
    if (v->indices().size() == 1) {
      return v;
    }
    return alloc<Load>(
        v->dtype(),
        v->buf(),
        std::vector<ExprPtr>({flatten_index(
            v->buf()->dims(), v->indices(), v->buf()->strides())}));
  }

  StmtPtr mutate(const StorePtr& v) override {
    ExprPtr value = v->value();
    ExprPtr new_value = value->accept_mutator(this);
    if (v->indices().size() == 1 && value == new_value) {
      return v;
    }
    std::vector<ExprPtr> indices = {
        flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides())};
    v->set_indices(indices);
    v->set_value(new_value);
    return v;
  }
};

static bool isValidIdentifierChar(char c, size_t pos) {
  return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
}

// replaces all invalid characters with underscore
std::string sanitizeName(const std::string& input_name) {
  std::stringstream sanitized_name;
  for (size_t i = 0; i < input_name.size(); ++i) {
    if (isValidIdentifierChar(input_name[i], i)) {
      sanitized_name << input_name[i];
    } else {
      if (i == 0) {
        // Don't start names with underscore
        sanitized_name << "v";
      }
      sanitized_name << "_";
    }
  }
  return sanitized_name.str();
}

class VarNameSanitizer : public IRMutator {
 public:
  ExprPtr mutate(const BufPtr& v) override {
    if (seen_bufs_.count(v)) {
      return v;
    }
    const std::string& name = v->name_hint();
    auto new_name = sanitizeName(name);
    if (taken_names_.count(new_name)) {
      new_name = getNextAvailableName(new_name);
    }
    v->set_name_hint(new_name);
    taken_names_.insert(new_name);
    seen_bufs_.insert(v);
    return v;
  }

  ExprPtr mutate(const VarPtr& v) override {
    if (seen_vars_.count(v)) {
      return v;
    }
    const std::string& name =
v->name_hint(); auto new_name = sanitizeName(name); if (taken_names_.count(new_name)) { new_name = getNextAvailableName(new_name); } v->set_name_hint(new_name); taken_names_.insert(new_name); seen_vars_.insert(v); return v; } StmtPtr mutate(const ForPtr& v) override { auto new_name = getNextAvailableName(getIndexVarNameAtLevel(level_)); if (seen_index_vars_.count(v->var())) { auto new_var = alloc("", v->var()->dtype()); Substitute(v, {{v->var(), new_var}}); } v->var()->set_name_hint(new_name); seen_index_vars_.insert(v->var()); seen_vars_.insert(v->var()); taken_names_.insert(new_name); level_++; v->body()->accept_mutator(this); level_--; v->start()->accept_mutator(this); v->stop()->accept_mutator(this); return v; } std::string getIndexVarNameAtLevel(int level_) { auto names_num = index_var_names_.size(); auto counter = level_ / names_num; if (counter == 0) { return index_var_names_[level_ % names_num]; } else { return index_var_names_[level_ % names_num] + std::to_string(counter); } } std::string getNextAvailableName(const std::string& base_name) { std::string name = base_name; int counter = 0; while (taken_names_.count(name)) { counter++; name = base_name + "_" + std::to_string(counter); } return name; } private: std::vector index_var_names_ = {"i", "j", "k", "l", "m", "n", "o", "p"}; std::unordered_set taken_names_; std::unordered_set seen_index_vars_; std::unordered_set seen_vars_; std::unordered_set seen_bufs_; int level_ = 0; }; StmtPtr LoopNest::sanitizeNames(StmtPtr s) { VarNameSanitizer r; s->accept_mutator(&r); return s; } class Vectorizer : public IRMutator { public: StmtPtr vectorize(ForPtr v) { StmtPtr body = v->body(); VarPtr var = v->var(); ExprPtr start = v->start(); ExprPtr stop = v->stop(); auto start_imm = intValue(start); auto stop_imm = intValue(stop); if (!start_imm) { // Can't vectorize due to non-constant loop start! success_ = false; return v; } if (!stop_imm) { // Can't vectorize due to non-constant loop stop! success_ = false; return v; } var_ = var; start_ = immLike(start, *start_imm); lanes_ = *stop_imm; StmtPtr new_body = body->accept_mutator(this); if (new_body == body) { // Vectorization failed! 
success_ = false; return v; } return new_body; } bool success() const { return success_; } ExprPtr mutate(const AddPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) + ExprHandle(inputs[1]); }); } ExprPtr mutate(const SubPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) - ExprHandle(inputs[1]); }); } ExprPtr mutate(const MulPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) * ExprHandle(inputs[1]); }); } ExprPtr mutate(const DivPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) / ExprHandle(inputs[1]); }); } ExprPtr mutate(const ModPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) % ExprHandle(inputs[1]); }); } ExprPtr mutate(const AndPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) & ExprHandle(inputs[1]); }); } ExprPtr mutate(const OrPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) | ExprHandle(inputs[1]); }); } ExprPtr mutate(const XorPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) ^ ExprHandle(inputs[1]); }); } ExprPtr mutate(const LshiftPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) << ExprHandle(inputs[1]); }); } ExprPtr mutate(const RshiftPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return ExprHandle(inputs[0]) >> ExprHandle(inputs[1]); }); } ExprPtr mutate(const MaxPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return Max::make( ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans()); }); } ExprPtr mutate(const MinPtr& v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { return Min::make( ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans()); }); } ExprPtr mutate(const CompareSelectPtr& v) override { std::vector inputs = { v->lhs(), v->rhs(), v->ret_val1(), v->ret_val2()}; return try_vectorize(v, inputs, [&]() { return CompareSelect::make( ExprHandle(inputs[0]), ExprHandle(inputs[1]), ExprHandle(inputs[2]), ExprHandle(inputs[3]), v->compare_select_op(), v->bias()); }); } ExprPtr mutate(const BitCastPtr& v) override { std::vector inputs = {v->src_value()}; return try_vectorize(v, inputs, [&]() { return BitCast::make( Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0])); }); } ExprPtr mutate(const CastPtr& v) override { std::vector inputs = {v->src_value()}; return try_vectorize(v, inputs, [&]() { return Cast::make( Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0])); }); } ExprPtr mutate(const VarPtr& v) override { if (v == var_) { return Ramp::make( ExprHandle(start_), ExprHandle(immLike(start_, 1)), lanes_) .node(); } return v; } ExprPtr mutate(const RampPtr& v) override { ExprPtr base = v->base(); ExprPtr stride = v->stride(); ExprPtr base_new = base->accept_mutator(this); ExprPtr stride_new = stride->accept_mutator(this); if (base_new == base && stride_new == 
stride) { return v; } // Can't vectorize a Ramp! success_ = false; return v; } ExprPtr mutate(const LoadPtr& v) override { Dtype dtype(v->dtype().scalar_type(), lanes_); BufPtr buf = v->buf(); std::vector inputs = {v->flat_index()}; return try_vectorize(v, inputs, [&]() { return Load::make(dtype, BufHandle(buf), {ExprHandle(inputs[0])}); }); } ExprPtr mutate(const ReduceOpPtr& v) override { Dtype dtype(v->dtype().scalar_type(), lanes_); std::vector inputs = {v->body()}; auto out = try_vectorize(v, inputs, [&]() { return ExprHandle( alloc(inputs[0], v->reduce_args(), v->reducer())); }); return out; } ExprPtr mutate(const BroadcastPtr& v) override { ExprPtr val = v->value(); ExprPtr new_val = val->accept_mutator(this); if (new_val == val) { return v; } // Can't vectorize a Broadcast! success_ = false; return v; } ExprPtr mutate(const IfThenElsePtr& v) override { ExprPtr condition = v->condition(); ExprPtr new_condition = condition->accept_mutator(this); if (new_condition != condition) { // Can't vectorize an IfThenElse condition! success_ = false; return v; } std::vector inputs = {v->true_value(), v->false_value()}; return try_vectorize(v, inputs, [&]() { return IfThenElse::make( ExprHandle(condition), ExprHandle(inputs[0]), ExprHandle(inputs[1])); }); } ExprPtr mutate(const IntrinsicsPtr& v) override { std::vector inputs = v->params(); return try_vectorize(v, inputs, [&]() { return ExprHandle(alloc(v->op_type(), inputs)); }); } StmtPtr mutate(const StorePtr& v) override { BufPtr buf = v->buf(); std::vector inputs = {v->flat_index(), v->value()}; return try_vectorize(v, inputs, [&]() { return Store::make( BufHandle(buf), {ExprHandle(inputs[0])}, ExprHandle(inputs[1])); }); } StmtPtr mutate(const ForPtr& v) override { VarPtr var = v->var(); ExprPtr start = v->start(); ExprPtr stop = v->stop(); LoopOptions loop_options = v->loop_options(); ExprPtr new_start = start->accept_mutator(this); ExprPtr new_stop = stop->accept_mutator(this); if (new_start != start || new_stop != stop) { // Can't vectorize nested For with dependent loop bounds! success_ = false; return v; } StmtPtr body = v->body(); StmtPtr new_body = body->accept_mutator(this); if (new_body == body) { return (ForPtr)v; } return alloc(var, new_start, new_stop, new_body, loop_options); } StmtPtr mutate(const BlockPtr& v) override { // IRMutator does in-place mutations. But the logic in vectorization checks // for success by looking for a new stmt. So, we override the in-place // mutations and create a clone here if any of its statements change. // TODO: Can we change the logic of vectorizer so that we don't need this? bool any_change = false; std::vector stmts; for (const StmtPtr& stmt : *v) { StmtPtr stmt_new = stmt->accept_mutator(this); if (stmt != stmt_new) { any_change = true; } else { stmt_new = Stmt::clone(stmt); } if (stmt_new) { stmts.push_back(stmt_new); } } if (any_change) { return alloc(stmts); } return v; } template ExprPtr try_vectorize(ExprPtr e, std::vector& inputs, T&& vec_ctor) { bool vectorize = vectorize_inputs(inputs); if (vectorize) { return vec_ctor().node(); } return e; } template StmtPtr try_vectorize(StmtPtr s, std::vector& inputs, T&& vec_ctor) { bool vectorize = vectorize_inputs(inputs); if (vectorize) { return vec_ctor(); } return s; } bool vectorize_inputs(std::vector& inputs) { bool any_vectorized = false; std::vector new_inputs; // Attempt to vectorize each input. 
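// Illustrative example (hypothetical IR, not taken from a real kernel): with
// var_ = i, start_ = 0 and lanes_ = 8, an input expression `i + 1` vectorizes
// to `Ramp(0, 1, 8) + Broadcast(1, 8)` -- the loop variable becomes a Ramp,
// and the scalar constant that did not vectorize on its own is wrapped in a
// Broadcast below.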
for (ExprPtr& in : inputs) { ExprPtr new_in = in->accept_mutator(this); new_inputs.push_back(new_in); if (new_in != in) { any_vectorized = true; } } // If none of them vectorized, then don't vectorize this. if (!any_vectorized) { return false; } // Insert broadcasts for any inputs that weren't vectorized. for (size_t i = 0; i < inputs.size(); ++i) { if (inputs[i] == new_inputs[i]) { inputs[i] = Broadcast::make(ExprHandle(inputs[i]), lanes_).node(); } else { inputs[i] = new_inputs[i]; } } // And then vectorize this node. return true; } VarPtr var_ = nullptr; int64_t lanes_ = 0; ExprPtr start_ = nullptr; bool success_ = true; }; bool LoopNest::vectorize(const ForPtr& f) { BlockPtr b = to(f->get_parent()); if (!b) { return false; } // Can't vectorize reduction axes. auto reductions = NodeFinder::find(f); for (const auto& r : reductions) { if (std::find(r->reduce_args().begin(), r->reduce_args().end(), f->var()) != r->reduce_args().end()) { return false; } } Vectorizer v; StmtPtr new_f = nullptr; new_f = Stmt::clone(f); normalize(to(new_f)); new_f = FlattenIndexes(new_f); new_f = v.vectorize(to(new_f)); if (!v.success()) { // We clone f before vectorizing. So, any partial vectorization will // have modified the clone. In case of an exception, we can continue // using f. new_f = f; } if (new_f != f) { b->replace_stmt(f, IRSimplifier::simplify(new_f)); return true; } // Vectorization was not successful. return false; } void LoopNest::initialize( const std::vector& output_tensors, const std::vector& tensors_to_compute) { for (const auto& t : output_tensors) { output_bufs_.insert(t.buf()); } std::vector loops; for (const Tensor& t : tensors_to_compute) { StmtPtr loop = t.stmt(); if (loop->get_parent()) { std::cerr << "Error: creating a loopnest from already used Tensors\n"; loops = {}; break; } // Flatten initializers. if (BlockPtr block = to(loop)) { for (const auto& s : block->stmts()) { block->remove_stmt(s); loops.push_back(s); } } else { loops.push_back(loop); } } root_stmt_ = alloc(loops); } class FunctionInliner : public IRMutator { public: FunctionInliner(StorePtr producer, std::unordered_set outputs) : buf_(producer->buf()), producer_(std::move(producer)), outputs_(std::move(outputs)) { for (const auto& i : producer_->indices()) { if (auto index_var = to(i)) { index_vars_.insert(index_var); producer_index_vars_.push_back(index_var); } else { // If the index can be a constant, then that dimension must have size 1 // (since we don't support in-place writes). Resolves issue 52581. auto index_val = evalInt(i); if (!index_val || *index_val != 0) { success_ = false; break; } producer_index_vars_.push_back(nullptr); } } } bool success() const { return success_; } private: ExprPtr mutate_loads(const BufPtr& buf, std::vector dims) { std::vector index_vars; if (buf->ndim() != producer_index_vars_.size()) { // Dimensions of producer and consumer expressions do not match in inliner // in the fuser success_ = false; return nullptr; } for (const auto i : c10::irange(buf->ndim())) { VarPtr func_callee_arg = producer_index_vars_.at(i); ExprPtr func_caller_param = dims.at(i); if (func_callee_arg == nullptr) { continue; } auto iter = inline_mapping_.find(func_callee_arg); if (iter != inline_mapping_.end()) { // Duplicated variables success_ = false; return nullptr; } // Add a mapping for each function parameter to it's source name. 
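// Illustrative example (hypothetical tensors): when inlining a producer
//   P[i, j] = A[i] * B[j]
// into a consumer load P[x + 1, y], the mapping built here is
//   i -> x + 1, j -> y
// so the cloned producer body is rewritten to A[x + 1] * B[y].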
inline_mapping_[func_callee_arg] = func_caller_param; GRAPH_DEBUG( "ComputeInline: Inline mapping: ", std::to_string(func_callee_arg), " -> ", std::to_string(func_caller_param)); index_vars.push_back(func_callee_arg); } // Call the actual replacement. ExprPtr body = producer_->value(); GRAPH_DEBUG("ComputeInline: Before rewriting body: ", std::to_string(body)); ExprPtr result = Expr::clone(body)->accept_mutator(this); GRAPH_DEBUG( "ComputeInline: After rewriting body: ", std::to_string(result)); // Remove the mappings we created for this function parameters. for (const auto& v : index_vars) { for (auto& pair : random_bindings_) { if (pair.second.erase(v)) { ExprPtr inlined = inline_mapping_[v]; for (const auto& nv : VarFinder::find(inlined)) { pair.second.insert(nv); } } } GRAPH_DEBUG("ComputeInline: Inline mapping: erasing", std::to_string(v)); inline_mapping_.erase(v); } return result; } ExprPtr mutate(const LoadPtr& v) override { if (!success()) { return v; } BufPtr buf = v->buf(); if (buf != buf_) { return IRMutator::mutate(v); } if (v->indices().size() != buf->ndim()) { // Number of indices doesn't match buf rank in the fuser success_ = false; return v; } auto result = mutate_loads(buf, v->indices()); if (!result) { // If we don't inline successfully return the given load. success_ = false; return v; } return result; } // Replace the target variable with the caller expressions. ExprPtr mutate(const VarPtr& v) override { if (!success()) { return v; } auto iter = inline_mapping_.find(v); if (iter == inline_mapping_.end()) { return v; } else { ExprPtr expr = iter->second; // Continue to transform the value from the lookup table. return expr->accept_mutator(this); } } // Handle random intrinsics which should be cached. ExprPtr mutate(const IntrinsicsPtr& v) override { if (!success()) { return v; } if (!in_producer_ || v->op_type() != kRand) { return IRMutator::mutate(v); } // Create a new Let Statement for the random variable, which we can refer // to multiple times and resolve the same value (ie. store it in a scalar // rather than the Tensor). const std::string& name = buf_->name_hint(); VarPtr new_var = alloc(name, v->dtype()); random_bindings_[alloc(new_var, v)] = index_vars_; GRAPH_DEBUG( "ComputeInline: created random bindings for ", std::to_string(new_var)); return new_var; } // Remove the buffer write from the inlined function. StmtPtr mutate(const StorePtr& v) override { if (!success()) { return v; } // If the buf_ is in the outputs set, keep its statement intact. Otherwise, // remove it. if (v == producer_ && !outputs_.count(buf_)) { in_producer_ = true; producer_ = to(IRMutator::mutate(v)); if (!producer_) { // Producer statement for output buf should remain non-null in the fuser success_ = false; return v; } in_producer_ = false; return nullptr; } else { return IRMutator::mutate(v); } } // Any Random Intrinsics that were turned into vars must be inserted here. StmtPtr mutate(const BlockPtr& v) override { if (!success()) { return v; } std::vector stmts; for (const StmtPtr& stmt : *v) { StmtPtr stmt_new = stmt->accept_mutator(this); if (!stmt_new) { continue; } if (stmt == stmt_new) { stmt_new = Stmt::clone(stmt); } stmts.push_back(stmt_new); } return Block::make(stmts); } StmtPtr mutate(const ForPtr& v) override { if (!success()) { return v; } ForPtr res = to(IRMutator::mutate(v)); if (!res) { return nullptr; } // Find any random bindings that should be defined in this loops body. 
std::vector bindings_this_loop; VarPtr fv = v->var(); for (auto& pair : random_bindings_) { auto& index_var = pair.second; if (index_var.erase(fv)) { bindings_this_loop.push_back(pair.first); } } for (const auto& l : bindings_this_loop) { res->body()->prepend_stmt(l); random_bindings_.erase(l); } return res; } private: BufPtr buf_; StorePtr producer_; // Index Vars present in the producer. std::unordered_set index_vars_; std::vector producer_index_vars_; std::unordered_map inline_mapping_; // In the producer's scope - we need to bind any calls to rand(). bool in_producer_ = false; std::unordered_map> random_bindings_; std::unordered_set outputs_; bool success_ = true; }; static StmtPtr computeInlineImpl( const BufPtr& b, const StmtPtr& stmt, const std::unordered_set& output_bufs) { // If buf is used or defined in an ExternalCall, we cannot inline it auto buf_load_store_uses = findLoadOrStoreUses(stmt); if (!buf_load_store_uses.count(b)) { return nullptr; } for (auto& use : buf_load_store_uses.at(b)) { StmtPtr s = use.s; if (to(s) || to(s)) { return nullptr; } } // Find producers. StorePtr relevant_store{nullptr}; auto stores = NodeFinder::find(stmt); for (const auto& s : stores) { if (s->buf() == b) { auto reductions = NodeFinder::find(s); if (!reductions.empty()) { // Cannot inline a reduction computation return nullptr; } if (relevant_store != nullptr) { // Cannot inline Buf with multiple Tensors return nullptr; } relevant_store = s; } } if (!relevant_store) { // Cannot find a relevant store to inline a buf in the fuser return nullptr; } GRAPH_DEBUG("ComputeInline: Def: ", std::to_string(relevant_store)); FunctionInliner inliner(relevant_store, output_bufs); auto result = stmt->accept_mutator(&inliner); if (inliner.success()) { return result; } return nullptr; } bool LoopNest::computeInline(const BufPtr& b) { // Inlining may not always be successful. Since all mutations now happen // in-place, an unsuccessful inlining transformation might leave the IR // in an invalid state. To get around this problem, we clone the root stmt, // try inlining on the clone, and if it succeeds, we proceed to perform // inlining on the actual root stmt. This way the root stmt will always be // in a valid state. auto stmt_copy = Stmt::clone(root_stmt_); auto try_inline = computeInlineImpl(b, stmt_copy, output_bufs_); if (!try_inline) { return false; } root_stmt_ = computeInlineImpl(b, root_stmt_, output_bufs_); return true; } bool LoopNest::computeInline(const StmtPtr& s) { auto s_store = to(s); if (s_store == nullptr) { // Could not find buffer producer to inline return false; } return computeInline(s_store->buf()); } // inlining buffers with multiple uses can create duplicated work, which can // slow down cpu code generation but is enabled on gpu because it avoids // difficult synchronization logic across blocks. 
Inlining trivial reads does // not duplicate work void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) { std::unordered_set bufs_to_inline; auto intermediate_bufs = getIntermediateBufs(); if (allow_duplicated_work) { bufs_to_inline.insert(intermediate_bufs.begin(), intermediate_bufs.end()); } else { auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_); auto input_bufs = getInputBufs(); for (const auto& buf : intermediate_bufs) { TORCH_INTERNAL_ASSERT( buf_load_store_uses.count(buf), buildErrorMessage( "Could not find uses of buf '" + buf->name_hint() + "' in the fuser.")); std::vector& uses = buf_load_store_uses[buf]; auto stores = c10::filter( uses, [](const BufLoadOrStoreUse& use) { return use.isStore; }); // if the intermediate is the buffer formed from reading in the input // tensors, always inline, bc we are not duplicating any work // and avoiding an intermediary buffer if (stores.size() == 1) { if (auto store = to(stores[0].s)) { auto input_as_load = to(store->value()); if (input_as_load && input_bufs.count(input_as_load->buf())) { bufs_to_inline.insert(buf); continue; } } else { // If S is not a store, it must be an ExternalCall. TORCH_INTERNAL_ASSERT( to(stores[0].s) || to(stores[0].s), buildErrorMessage( "Expected stmt: " + std::to_string(stores[0].s) + "\nto be either a Store or an ExternalCall in the fuser.")); } } // all bufs will have at least one store (if they have > 1 they can't be // inlined anyway) size_t reads = uses.size() - 1; // if only one read, we can inline it without duplicating work if (reads <= 1) { bufs_to_inline.insert(buf); } } } if (allow_duplicated_work) { bufs_to_inline.insert(output_bufs_.begin(), output_bufs_.end()); } for (const auto& b : bufs_to_inline) { computeInline(b); } } // TODO: Unify with DepTracker class LoadOrStoreUseFinder : public IRVisitor { public: std::unordered_map> findUses( const StmtPtr& s) { uses_.clear(); s->accept(this); return uses_; } private: void visit(const StorePtr& v) override { if (stores_[v->buf()].insert(last_stmt_).second) { uses_[v->buf()].push_back({(StmtPtr)v, true}); } last_stmt_ = (StmtPtr)v; IRVisitor::visit(v); } void visit(const ExternalCallPtr& v) override { if (stores_[v->buf()].insert(last_stmt_).second) { uses_[v->buf()].push_back({(StmtPtr)v, true}); } last_stmt_ = (StmtPtr)v; for (const BufPtr& input_buf : v->buf_args()) { if (loads_[input_buf].insert(last_stmt_).second) { uses_[input_buf].push_back({last_stmt_, false}); } } IRVisitor::visit(v); } void visit(const ExternalCallWithAllocPtr& v) override { for (const auto& out_buf : v->buf_out_args()) { if (stores_[out_buf].insert(last_stmt_).second) { uses_[out_buf].push_back({(StmtPtr)v, true}); } } last_stmt_ = (StmtPtr)v; for (const auto& input_buf : v->buf_args()) { if (loads_[input_buf].insert(last_stmt_).second) { uses_[input_buf].push_back({last_stmt_, false}); } } IRVisitor::visit(v); } void visit(const LoadPtr& v) override { if (loads_[v->buf()].insert(last_stmt_).second) { uses_[v->buf()].push_back({last_stmt_, false}); } IRVisitor::visit(v); } StmtPtr last_stmt_ = nullptr; std::unordered_map> uses_; // Sets of loads and stores in order to keep the results unique std::unordered_map> loads_; std::unordered_map> stores_; }; std::unordered_map> findLoadOrStoreUses( const StmtPtr& s) { LoadOrStoreUseFinder uf; return uf.findUses(s); } class ContainedStmtsFinder : public IRVisitor { public: // Simply list all Stores and Block that are children of the given stmt const std::unordered_set& findContainedStmts(const StmtPtr& s) { 
contained_.clear(); s->accept(this); return contained_; } private: void visit(const StorePtr& v) override { contained_.insert((StmtPtr)v); IRVisitor::visit(v); } void visit(const ExternalCallPtr& v) override { contained_.insert((StmtPtr)v); IRVisitor::visit(v); } void visit(const ExternalCallWithAllocPtr& v) override { contained_.insert((StmtPtr)v); IRVisitor::visit(v); } void visit(const BlockPtr& v) override { contained_.insert((StmtPtr)v); IRVisitor::visit(v); } std::unordered_set contained_; }; class StmtDeleter : public IRMutator { public: StmtDeleter(const std::unordered_set& targets) : targets_(targets) {} private: StmtPtr mutate(const BlockPtr& v) override { std::vector stmts; for (const auto& s : v->stmts()) { if (targets_.count(s) == 0) { StmtPtr ns = s->accept_mutator(this); if (ns) { stmts.push_back(Stmt::clone(ns)); } } } return Block::make(stmts); } const std::unordered_set& targets_; }; void LoopNest::eliminateDeadStores() { using namespace analysis; MemDependencyChecker checker(getInputBufs(), getOutputBufs()); root_stmt_->accept(&checker); std::unordered_set deadStores; std::vector> outputAccesses; for (const auto& o : getOutputBufs()) { outputAccesses.push_back(checker.output(o)); } for (auto& info : checker.getHistory()) { if (!info->isWrite()) { continue; } bool found = false; for (auto& output : outputAccesses) { if (checker.dependsIndirectly(output, info)) { found = true; break; } } if (!found) { deadStores.insert(info->stmt()); } } StmtDeleter deleter(deadStores); root_stmt_ = root_stmt_->accept_mutator(&deleter); } void LoopNest::prepareForCodegen() { // Expand reduction ops. ReductionExpander reduceExpander; root_stmt_ = reduceExpander.expand(root_stmt_); root_stmt_ = FlattenIndexes(root_stmt_); } namespace { // This is extended from IRCloner instead of IRMutator because we want all // the rest of the IR nodes (the ones not touched directly) to be cloned. class IfThenElseReplacer : public IRCloner { public: IfThenElseReplacer(IfThenElsePtr to_replace, ExprPtr new_expr) : to_replace_(std::move(to_replace)), new_expr_(std::move(new_expr)) {} ExprPtr mutate(const IfThenElsePtr& i) override { if (i == to_replace_) { return new_expr_; } return IRCloner::mutate(i); } private: IfThenElsePtr to_replace_; ExprPtr new_expr_; }; // Check if the given condition is optimizable. // Specifically, this function looks for the following pattern: // "var < expr" // // If this pattern is found, then this function: // * sets `cond_var` to `var`, // * sets `compared_value` to `expr`, and // * returns true. bool isConditionOptimizable( const ExprPtr& condition, VarPtr* cond_var, ExprPtr* compared_value) { auto cs = to(condition); if (cs && cs->compare_select_op() == kLT) { auto var = to(cs->lhs()); if (var) { *cond_var = var; *compared_value = cs->rhs(); return true; } } return false; } // Checks if the given if-then-else expression is a conditional that is // generated from `aten::cat`. // // The expected format of conditionals is: // IfThenElse(var < val1? 1 : 0, // IfThenElse (var < val2? 1 : 0, // IfThenElse (var < val3? 1 : 0, // sub-expr1, // sub-expr2), // sub-expr3), // sub-expr4) // // If such a conditional is found, this function also sets: // * cond_var to the condition variable found in this expression. // * comp_values to the list of compared values in the condition expressions. // * sub_exprs to the list of sub-expressions that are the result of this // if-then-else expression. 
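// Illustrative example (hypothetical buffers and sizes): concatenating two
// tensors of sizes 3 and 5 along a dimension indexed by `i` can produce
//   IfThenElse(i < 3 ? 1 : 0, A[i], B[i - 3])
// for which this function sets cond_var = i, comp_values = {3}, and
// sub_exprs = {A[i], B[i - 3]}.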
bool isConditionalFromCat( const IfThenElsePtr& ite, VarPtr* cond_var, std::vector* comp_values, std::vector* sub_exprs) { VarPtr var = nullptr; ExprPtr comp_value; if (isConditionOptimizable(ite->condition(), &var, &comp_value)) { if (*cond_var == nullptr) { *cond_var = var; } else if (*cond_var != var) { // Different condition variables found in nested if-then-else // expressions. Can not optimize such cases. return false; } auto true_ite = to(ite->true_value()); if (true_ite) { if (!isConditionalFromCat(true_ite, cond_var, comp_values, sub_exprs)) { return false; } } else { sub_exprs->push_back(ite->true_value()); } auto false_ite = to(ite->false_value()); if (false_ite) { return false; } comp_values->push_back(comp_value); sub_exprs->push_back(ite->false_value()); return true; } return false; } bool areConstantsAndSorted(const std::vector& comp_values) { std::vector comp_consts; comp_consts.reserve(comp_values.size()); for (const auto& c : comp_values) { if (!c->isConstant()) { return false; } comp_consts.push_back(immediateAs(c)); } return std::is_sorted(comp_consts.begin(), comp_consts.end()); } } // namespace bool LoopNest::optimizeConditionals() { // Consider every store in the root_stmt_ and try to optimize the // conditionals in that store. auto stores = NodeFinder::find(root_stmt_); std::unordered_set split_fors; for (const auto& store : stores) { VarPtr cond_var = nullptr; // `comp_values` represent the list of compared values that will be // collected as we check for the expected pattern. Since that will // only include the RHS of the conditions in the if-then-else expressions // we need to start with `0` which is the initial bound, given that we // only handle normalized loops (check for this is done below). std::vector comp_values; std::vector sub_exprs; auto ifthenelse_exprs = NodeFinder::find(store); if (ifthenelse_exprs.empty()) { continue; } // We only check if the first if-then-else expression in this store // corresponds to a conditional of the required format. If there are more // than one such conditional, optimizing them requires checking if the // conditions are exactly the same across them and handling all of them // together. Currently, this is not handled. if (!isConditionalFromCat( ifthenelse_exprs.front(), &cond_var, &comp_values, &sub_exprs)) { continue; } TORCH_INTERNAL_ASSERT( !comp_values.empty(), buildErrorMessage( "Expected at least one expression in optimizeConditional in the fuser.")); comp_values.insert(comp_values.begin(), immLike(comp_values[0], 0)); auto fors = getLoopStmtsFor(store); if (cond_var != fors.back()->var()) { // Currently, we only handle the case where the condition variable // is the same as the inner-most loop variable. // TODO: Handle all other cases here. // // In order to handle all other cases, the method `clone_and_replace` // called below to clone the body of the loop with a new store needs // to recursively handle cloning of the loops and other blocks it // contains. continue; } auto for_to_split = fors.back(); if (!LoopNest::isNormalized(for_to_split)) { // Do not optimize this conditional since the condition variable // refers to a loop that is not normalized. continue; } if (split_fors.count(for_to_split)) { // This loop has already been split while optimizing conditionals // earlier. // // Optimizing multiple conditionals that require splitting the same loop // is tricky. It requires checking if the conditions are exactly the same // across them and handling all of them together by splitting the loop // exactly once. 
// // Currently, this case is not supported. continue; } split_fors.insert(for_to_split); // `comp_values` needs to include the end bound, which is `for_to_split` // stop value. comp_values.push_back(for_to_split->stop()); // Check if all `comp_values` are constants and they are sorted. if (!areConstantsAndSorted(comp_values)) { continue; } // Remove all the if-then-else expressions from this store and create // one loop per sub-expression. std::vector split_loops; auto cond_to_replace = ifthenelse_exprs.front(); for (size_t i = 0; i < sub_exprs.size(); ++i) { IfThenElseReplacer ifthenelseReplacer(cond_to_replace, sub_exprs[i]); auto new_store = store->accept_mutator(&ifthenelseReplacer); auto new_for_body = for_to_split->body()->clone_and_replace(store, new_store); auto new_for = alloc( for_to_split->var(), comp_values[i], comp_values[i + 1], new_for_body); LoopNest::normalize(new_for); split_loops.push_back(new_for); } auto par = to(for_to_split->get_parent()); par->replace_stmt(for_to_split, alloc(split_loops)); } root_stmt_ = IRSimplifier::simplify(root_stmt_); return true; } void LoopNest::vectorizeInnerLoops() { std::vector innerLoops; std::vector worklist; // Find outer-most For loops if (ForPtr rootF = to(root_stmt_)) { worklist.push_back(rootF); } else if (BlockPtr body = to(root_stmt_)) { std::vector blocks = {body}; while (!blocks.empty()) { BlockPtr b = blocks.back(); blocks.pop_back(); for (const StmtPtr& s : *b) { if (const ForPtr& f = to(s)) { worklist.push_back(f); } else if (BlockPtr b2 = to(s)) { blocks.push_back(b2); } } } } // Traverse the For loop nest find inner-most loops, which are // vectorization candidates. while (!worklist.empty()) { ForPtr f = worklist.back(); worklist.pop_back(); bool containsSubLoops = false; if (BlockPtr body = to(f->body())) { for (const StmtPtr& s2 : *body) { if (const ForPtr& f2 = to(s2)) { containsSubLoops = true; worklist.push_back(f2); } } } if (!containsSubLoops) { innerLoops.push_back(f); } } // vectorize inner loops. 
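// Illustrative example (hypothetical extent): an inner-most loop with 21
// iterations is split into 2 x 8 vectorized iterations plus a tail of 5; the
// tail is split again into 1 x 4 vectorized iterations plus a scalar tail of 1.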
for (const ForPtr& loop : innerLoops) { ForPtr split1; ForPtr tail1; static const int kBodyVectorWidth = 8; splitWithTail(loop, kBodyVectorWidth, &split1, &tail1); vectorize(split1); if (tail1) { ForPtr split2; ForPtr tail2; static const int kTailVectorWidth = 4; splitWithTail(tail1, kTailVectorWidth, &split2, &tail2); vectorize(split2); } } } void LoopNest::sliceHead( const ForPtr& f, int factor, ForPtr* head, ForPtr* tail) { if (intValue(f->start()) && intValue(f->stop())) { auto start_val = *intValue(f->start()); auto stop_val = *intValue(f->stop()); auto size_val = stop_val - start_val; if (factor >= size_val) { *head = f; *tail = nullptr; return; } } if (!f) { throw malformed_input("sliceHead attempted on null loop"); } BlockPtr p = to(f->get_parent()); if (!p) { throw malformed_input("sliceHead attempted on loop with no parent"); } ExprPtr head_end = alloc( alloc(f->start(), immLike(f->stop(), factor)), f->stop(), true); *head = alloc(f->var(), f->start(), head_end, Stmt::clone(f->body())); p->insert_stmt_before(*head, f); f->set_start(head_end); *tail = f; if (f->loop_options().is_gpu_block_index() || f->loop_options().is_gpu_thread_index()) { LoopNest::normalize(*tail); } } void LoopNest::sliceHead(const ForPtr& f, int factor) { ForPtr head, tail; sliceHead(f, factor, &head, &tail); } void LoopNest::sliceTail( const ForPtr& f, int factor, ForPtr* head, ForPtr* tail) { if (intValue(f->start()) && intValue(f->stop())) { auto start_val = *intValue(f->start()); auto stop_val = *intValue(f->stop()); auto size_val = stop_val - start_val; if (factor >= size_val) { *head = nullptr; *tail = f; return; } } if (!f) { throw malformed_input("sliceTail attempted on null loop"); } BlockPtr p = to(f->get_parent()); if (!p) { throw malformed_input("sliceTail attempted on loop with no parent"); } ExprPtr tail_start = alloc( f->start(), alloc(f->stop(), immLike(f->stop(), factor)), true); *tail = alloc(f->var(), tail_start, f->stop(), Stmt::clone(f->body())); p->insert_stmt_after(*tail, f); f->set_stop(tail_start); *head = f; if (f->loop_options().is_gpu_block_index() || f->loop_options().is_gpu_thread_index()) { LoopNest::normalize(*head); } } void LoopNest::sliceTail(const ForPtr& f, int factor) { ForPtr head, tail; sliceTail(f, factor, &head, &tail); } void LoopNest::splitWithTail(const ForPtr& f, int factor) { ForPtr inner, tail; splitWithTail(f, factor, &inner, &tail); } void LoopNest::splitWithTail( const ForPtr& f, int factor, ForPtr* inner, ForPtr* tail) { if (!f) { throw malformed_input("splitWithTail attempted on null loop"); } BlockPtr p = to(f->get_parent()); if (!p) { throw malformed_input("splitWithTail attempted on loop with no parent"); } // Normalize the loop to simplify start and stop bound computation normalize(f); bool tail_is_needed = true; if (intValue(f->start()) && intValue(f->stop())) { auto const start_val = *intValue(f->start()); auto const stop_val = *intValue(f->stop()); auto const size_val = stop_val - start_val; auto const tail_size = size_val % factor; if (tail_size == 0) { tail_is_needed = false; } } ExprPtr factor_expr = immLike(f->stop(), factor); ExprPtr size = alloc(f->stop(), f->start()); ExprPtr split_count = alloc
(size, factor_expr); ExprPtr tail_size = alloc(size, factor_expr); const std::string& loop_var_name = f->var()->name_hint(); Dtype loop_var_dtype = f->var()->dtype(); VarPtr i_inner = alloc(loop_var_name + "_inner", loop_var_dtype); VarPtr i_outer = alloc(loop_var_name + "_outer", loop_var_dtype); // x -> x.outer * inner.size + x.inner ExprPtr combined_index1 = alloc(alloc(i_outer, factor_expr), i_inner); if (tail_is_needed) { VarPtr i_tail = alloc(loop_var_name + "_tail", loop_var_dtype); // x -> x.tail + outer.size * inner.size ExprPtr combined_index2 = alloc(i_tail, alloc(split_count, factor_expr)); StmtPtr body_tail = SubstituteInClone(f->body(), {{f->var(), combined_index2}}); *tail = alloc(i_tail, immLike(tail_size, 0), tail_size, body_tail); p->insert_stmt_after(*tail, f); } else { *tail = nullptr; } StmtPtr body_inner = Substitute(f->removeBody(), {{f->var(), combined_index1}}); *inner = alloc(i_inner, immLike(factor_expr, 0), factor_expr, body_inner); // The input loop `f` will be the outer loop after split. f->set_var(i_outer); f->set_start(immLike(split_count, 0)); f->set_stop(split_count); f->set_body(*inner); } void LoopNest::splitWithMask(const ForPtr& f, int factor) { ForPtr inner; splitWithMask(f, factor, &inner); } void LoopNest::splitWithMask(const ForPtr& f, int factor, ForPtr* inner) { BlockPtr p = to(f->get_parent()); if (!p) { std::cerr << "Parent is not a Block!\n"; return; } bool tail_is_needed = true; ExprPtr start = IRSimplifier::simplify(f->start()); ExprPtr stop = IRSimplifier::simplify(f->stop()); if (start->isConstant() && stop->isConstant()) { auto start_val = *intValue(start); auto stop_val = *intValue(stop); auto size_val = stop_val - start_val; auto tail_size = size_val % factor; if (tail_size == 0) { tail_is_needed = false; } } auto factor_expr = immLike(f->stop(), factor); ExprPtr size = alloc(f->stop(), f->start()); // split_count = (size + factor - 1) / factor ExprPtr split_count = alloc
( alloc(alloc(size, factor_expr), immLike(size, 1)), factor_expr); const std::string& loop_var_name = f->var()->name_hint(); Dtype loop_var_dtype = f->var()->dtype(); VarPtr i_inner = alloc(loop_var_name + "_inner", loop_var_dtype); VarPtr i_outer = alloc(loop_var_name + "_outer", loop_var_dtype); // x -> x.outer * inner.size + x.inner ExprPtr combined_index = alloc(alloc(i_outer, factor_expr), i_inner); StmtPtr body_inner = f->removeBody(); // TODO: is it ok that we're doing it eagerly? In the other implementation we // are only materializing predicates at the last, lowering, step. if (tail_is_needed) { auto start = intValue(f->start()); if (!start || *start != 0) { throw unimplemented_lowering(); } ExprPtr predicate = CompareSelect::make(ExprHandle(f->var()), ExprHandle(f->stop()), kLT) .node(); body_inner = Cond::make(ExprHandle(predicate), body_inner, nullptr); } body_inner = Substitute(body_inner, {{f->var(), combined_index}}); *inner = alloc(i_inner, immLike(factor_expr, 0), factor_expr, body_inner); // The input loop `f` will be the outer loop after split. f->set_var(i_outer); f->set_start(immLike(split_count, 0)); f->set_stop(split_count); f->set_body(*inner); } std::vector LoopNest::distributeLoop( const ForPtr& loop, const std::unordered_set& pivots) { TORCH_INTERNAL_ASSERT( loop, buildErrorMessage( "Expected non-null loop in distributeLoop in the fuser.")); auto root = loop->get_parent(); if (root == nullptr) { throw malformed_input("Loop without parent: ", loop); } auto root_block = to(root); if (root_block == nullptr) { throw malformed_input( "Loop's parent must be a Block, instead found ", root); } // Extract bodies for all the loops after distribution. std::vector new_loop_bodies; auto new_loop_body = alloc(std::vector({})); while (!loop->body()->empty()) { auto s = loop->body()->front(); loop->body()->remove_stmt(s); new_loop_body->append_stmt(s); if (pivots.count(s)) { new_loop_bodies.push_back(new_loop_body); new_loop_body = alloc(std::vector({})); } } if (!new_loop_body->empty()) { new_loop_bodies.push_back(new_loop_body); } // The first loop body has to be in the original loop. loop->body()->splice(loop->body()->begin(), new_loop_bodies.front()); std::vector new_loops = {loop}; // Create loops for all the remaining blocks. // Add all the new loops to the parent block. 
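// Illustrative example (hypothetical statements): distributing
//   for i { S1; S2; S3 }
// over pivots {S1} leaves { S1 } in the original loop and creates one cloned
// loop
//   for i { S2; S3 }
// inserted right after it in the parent block.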
for (size_t i = 1; i < new_loop_bodies.size(); ++i) { auto new_loop = loop->cloneWithNewBody(new_loop_bodies[i]); root_block->insert_stmt_after(new_loop, new_loops.back()); new_loops.push_back(new_loop); } return new_loops; } std::vector LoopNest::distributeLoop(const ForPtr& loop) { std::unordered_set stmtsInBlock( loop->body()->begin(), loop->body()->end()); return distributeLoop(loop, stmtsInBlock); } std::vector LoopNest::distributeLoopAndParents(const ForPtr& loop) { auto parentLoop = getParentLoop(loop); auto result = distributeLoop(loop); if (parentLoop) { return distributeLoopAndParents(parentLoop); } return result; } std::vector LoopNest::distributeLoopOverInnerLoops(const ForPtr& loop) { auto loops = NodeFinder::find(loop); std::unordered_set loopsSet(loops.begin(), loops.end()); return distributeLoop(loop, loopsSet); } std::vector LoopNest::distributeLoopAndParentsOverInnerLoops( const ForPtr& loop) { auto parentLoop = getParentLoop(loop); auto result = distributeLoopOverInnerLoops(loop); if (parentLoop) { return distributeLoopAndParentsOverInnerLoops(parentLoop); } return result; } static bool areEqual(const ExprPtr& expr1, const ExprPtr& expr2) { auto diff = IRSimplifier::simplify(alloc(expr1, expr2)); return diff->isConstant() && (immediateAs(diff) == 0); } static bool doesExprContainAnyVar( const ExprPtr& expr, const std::unordered_set& vars) { for (const auto& v : VarFinder::find(expr)) { if (vars.count(v)) { return true; } } return false; } // Returns true if the given list of indices refer to two accesses // that are loop-independent w.r.t. the given list of outer loop // variables. static bool areIndicesLoopIndependent( const std::vector& expr_list1, const std::vector& expr_list2, const std::unordered_set& outer_loop_vars) { if (expr_list1.size() != expr_list2.size()) { return false; } for (size_t i = 0; i < expr_list1.size(); ++i) { const auto& expr1 = expr_list1[i]; const auto& expr2 = expr_list2[i]; if (doesExprContainAnyVar(expr1, outer_loop_vars) || doesExprContainAnyVar(expr2, outer_loop_vars)) { if (!areEqual(expr1, expr2)) { return false; } } } return true; } bool LoopNest::hasLoopCarriedDependence(const ForPtr& loop) { analysis::MemDependencyChecker analyzer; loop->accept(&analyzer); std::unordered_set outer_loop_vars = {loop->var()}; auto outer_loops = LoopNest::getEnclosingLoopNest(loop); for (const auto& l : outer_loops) { outer_loop_vars.insert(l->var()); } // High-level algorithm to check if two accesses to a buffer, A and B, one of // which is a Store, result in a loop-carried dependence: // 1. For every pair of index expressions, Ai and Bi, that refer to a dim // of A and B, if one of the following conditions are satisfied: // a) Ai and Bi are equal (OR) // b) Both Ai and Bi do not contain any outer-loop variables // then, the dependence between A and B is a loop-independent // dependence. This is because, in the case of b), those index // expressions do not affect the ordering of accesses A and B. // 2. If condition 1) is not satisfied: // a) if the bounds on the accesses overlap, then this is a // loop-carried dependence. // b) if the bounds on the accesses do not overlap, then there is no // dependence. // // NOTE: Since we check for equality of index expressions whenever outer // loop variables are involved, this may incorrectly report some cases as // having a loop-carried dependence. It is impractical to handle all // possible cases here, so, we are being conservative and allow for // some false positives. 
While this will prevent some loop fusion // opportunities, that should be a small fraction of the cases that are // allowed. // // Implementation: // // For every pair of statements, S1 and S2, in the loop: // * Get the loads and stores in S1 and S2. // * For every store in S1 and load in S2 to the same buffer, if the index // expressions are not equal and there is an overlap in accesses, return // true to indicate a loop-carried dependence. // * For every load in S1 and store in S2 to the same buffer, if the index // expressions are not equal and there is an overlap in accesses, return // true to indicate a loop-carried dependence. // * For every store in S1 and store in S2 to the same buffer, if the index // expressions are not equal and there is an overlap in accesses, return // true to indicate a loop-carried dependence. for (auto it1 = loop->body()->begin(); it1 != loop->body()->end(); ++it1) { for (auto it2 = std::next(it1); it2 != loop->body()->end(); ++it2) { auto aStores = NodeFinder::find(*it1); auto aLoads = NodeFinder::find(*it1); auto bStores = NodeFinder::find(*it2); auto bLoads = NodeFinder::find(*it2); // ReadAfterWrite for (auto& aStore : aStores) { for (auto& bLoad : bLoads) { if (aStore->buf() == bLoad->buf()) { if (!areIndicesLoopIndependent( aStore->indices(), bLoad->indices(), outer_loop_vars)) { if (isOverlapping(analyzer, aStore, bLoad)) { return true; } } } } } // WriteAfterRead for (auto& bStore : bStores) { for (auto& aLoad : aLoads) { if (bStore->buf() == aLoad->buf()) { if (!areIndicesLoopIndependent( bStore->indices(), aLoad->indices(), outer_loop_vars)) { if (isOverlapping(analyzer, bStore, aLoad)) { return true; } } } } } // WriteAfterWrite for (auto& aStore : aStores) { for (auto& bStore : bStores) { if (aStore->buf() == bStore->buf()) { if (!areIndicesLoopIndependent( aStore->indices(), bStore->indices(), outer_loop_vars)) { if (isOverlapping(analyzer, aStore, bStore)) { return true; } } } } } } } return false; } bool LoopNest::unsafeFuseLoops( const std::vector& loops, ForPtr* fused) { if (loops.empty()) { return false; } if (loops.size() == 1) { *fused = loops.front(); return true; } // Check if all the loops have the same parent. auto root = loops.front()->get_parent(); for (const auto& l : loops) { auto par = l->get_parent(); if (par == nullptr) { return false; } if (par != root) { return false; } } auto root_block = to(root); if (root_block == nullptr) { return false; } // Currently, we only handle cases where there are no statements between // the given loops in their parents body. We can possibly relax this // constraint by allowing statements that do not affect the loops being // fused by performing some dependency analysis. TODO. auto it = root_block->begin(); for (; it != root_block->end(); ++it) { if (*it == loops.front()) { break; } } TORCH_INTERNAL_ASSERT( it != root_block->end(), buildErrorMessage( "Could not find the given loop in the root stmt in unsafeFuseLoop the fuser.")); for (const auto& l : loops) { if (*it != l) { return false; } ++it; } const auto& first_loop = loops.front(); // Fuse the loops by taking all the statements from the second loops // onwards and moving them into the first loop's body. // This way the final fused loop will be the same as the first loop. 
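// Illustrative example (hypothetical loops): fusing
//   for i in 0..N { A[i] = ... }
//   for j in 0..N { B[j] = ... }
// rewrites the second body in terms of `i` and splices it into the first
// loop, giving
//   for i in 0..N { A[i] = ...; B[i] = ... }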
for (size_t i = 1; i < loops.size(); ++i) { auto body = to(SubstituteInClone( loops[i]->body(), {{loops[i]->var(), first_loop->var()}})); first_loop->body()->splice(first_loop->body()->end(), body); root_block->remove_stmt(loops[i]); } *fused = loops.front(); return true; } bool LoopNest::fuseLoops(const std::vector& loops, ForPtr* fused) { if (loops.empty()) { return false; } if (loops.size() == 1) { *fused = loops.front(); return true; } // Check if bounds are the same for all the loops. const auto& first_loop = loops.front(); auto first_loop_start = IRSimplifier::simplify(first_loop->start()); auto first_loop_stop = IRSimplifier::simplify(first_loop->stop()); for (size_t i = 1; i < loops.size(); ++i) { const auto& curr_loop = loops[i]; auto curr_loop_start = IRSimplifier::simplify(curr_loop->start()); auto curr_loop_stop = IRSimplifier::simplify(curr_loop->stop()); if (!areEqual(curr_loop_start, first_loop_start)) { return false; } if (!areEqual(curr_loop_stop, first_loop_stop)) { return false; } } // We need to check if fusing the loops results in a loop-carried dependence. // This check can be done only after the loops are fused into one. But if the // check is violated, we need to return the given loops in the original form. // So, we create a clone of all the loops, fuse them and check for this. std::vector loops_copy; loops_copy.reserve(loops.size()); BlockPtr parent = alloc(std::vector({})); for (auto& l : loops) { auto l_copy = Stmt::clone(l); loops_copy.push_back(to(l_copy)); parent->append_stmt(l_copy); } ForPtr fused_copy; bool ret = unsafeFuseLoops(loops_copy, &fused_copy); if (!ret || hasLoopCarriedDependence(fused_copy)) { return false; } // Now that all conditions are satisfied, we fuse the given loops. return unsafeFuseLoops(loops, fused); } ForPtr LoopNest::findOuterFor(ForPtr a, ForPtr b) { StmtPtr s = b; // guess b is the latter. while (s != nullptr) { if (s == a) { // yes, b is after a. return a; } s = s->get_parent(); } // check that the two are in the same loop nest. s = a; while (s != nullptr) { if (s == b) { // a is after b. return b; } s = s->get_parent(); } // a and b have no relationship. return nullptr; } void LoopNest::reorderAxis(const ForPtr& a, const ForPtr& b) { if (a == b) { // nothing to do. return; } // find inner and outer. ForPtr outer = findOuterFor(a, b); if (outer == nullptr) { throw std::runtime_error("Reordered a loop not in LoopNest"); } ForPtr inner = a == outer ? b : a; std::deque internal_axes; // Find relevant axes, store reversed. StmtPtr s = inner; while (s != outer) { if (const ForPtr& f = to(s)) { internal_axes.push_back(f); } s = s->get_parent(); } internal_axes.push_back(outer); BlockPtr root = to(outer->get_parent()); CHECK(root); // Do a shallow copy of the inner blocks. BlockPtr body = alloc(std::vector({})); body->splice(body->end(), inner->body()); const ForPtr& before{outer}; ForPtr after{nullptr}; ForPtr last = internal_axes.front(); StmtPtr newInner = body; s = inner; while (s != outer) { if (auto cond = to(s->get_parent())) { if (s == cond->true_stmt()) { newInner = cond->cloneWithNewBody(newInner); } else { // s is the false branch of Cond newInner = cond->cloneWithNewBodies( alloc(std::vector({})), newInner); } } s = s->get_parent(); } // This is the major complexity in loop reordering: handling statements not in // the straight line of the reorder. To handle this we partition the tree into // the section before the critical path and after the critical path. // // An example of this pattern is: // for i in .. 
// Statement A // for j in .. // Statement B // Statement C // // When reordering loop i and j we need to ensure that Statement A and C are // still both executed with the loop extents of i, and that the three // statements are not reordered (as much as possible). for (const auto& loop : internal_axes) { // If the inner loop had a component after the loop we must wrap it in a For // loop matching this level of the tree. if (after != nullptr) { after = loop->cloneWithNewBody(after); } bool pastMidpoint = false; bool hadBeforeStmts = false; for (auto I = loop->body()->begin(), E = loop->body()->end(); I != E;) { // Be careful not to invalidate the iterator. StmtPtr s = *(I++); if (s == last) { // This is the midpoint. loop->body()->remove_stmt(s); if (!hadBeforeStmts) { // If there were no existing statements this loop does not need to be // preserved and we can roll it into the above loop. last = loop; } pastMidpoint = true; } else if (pastMidpoint) { // Statements after the reordered path must be moved to a new tree after // the reordered statement has occurred to preserve ordering. loop->body()->remove_stmt(s); if (after == nullptr) { after = loop->cloneWithNewBody(s); } else { after->body()->append_stmt(s); } } else { // We can leave any statements before the reordered loop alone, so long // as we preserve the loop structure. hadBeforeStmts = true; } } } // now we can actually reorder the chosen axes. std::swap(internal_axes.front(), internal_axes.back()); // Create the reordered internals: for (const auto& loop : internal_axes) { newInner = loop->cloneWithNewBody(newInner); } // Append the new statements to the root of the tree. if (before->body()->nstmts() == 0) { // If the top level is now empty, eliminate it. root->replace_stmt(before, newInner); } else { root->insert_stmt_after(newInner, before); } if (after) { root->insert_stmt_after(after, newInner); } } static bool isTrivialPermutation(const std::vector& permutation) { for (size_t i = 0; i < permutation.size(); ++i) { if (permutation[i] != i) { return false; } } return true; } static bool isValidPermutation(std::vector permutation) { std::sort(permutation.begin(), permutation.end()); return isTrivialPermutation(permutation); } std::vector LoopNest::reorder( const std::vector& loops, const std::vector& permutation) { if (loops.size() != permutation.size()) { throw malformed_input("invalid permutation size"); } if (isTrivialPermutation(permutation)) { return loops; } if (!isValidPermutation(permutation)) { throw malformed_input("invalid permutation for reorder"); } if (loops.size() < 2) { return loops; } if (!areLoopsPerfectlyNested(loops)) { throw malformed_input("reorder is only allowed on perfectly nested loops"); } auto parent = to(loops.front()->get_parent()); if (parent == nullptr) { throw malformed_input("parent of the loops must be a Block"); } // Reorder the loops according to the permutation. std::vector result(loops.size()); for (size_t i = 0; i < loops.size(); ++i) { result[i] = loops[permutation[i]]; } // Remove the bodies from all the loops. auto innermost_body = loops.back()->removeBody(); // We use an empty block statement to replace the outermost loop // so that we know the position where the outermost reordered loop // is to be inserted. 
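// Illustrative example (hypothetical loops): reorder({i, j, k}, {2, 0, 1})
// turns
//   for i { for j { for k { body } } }
// into
//   for k { for i { for j { body } } }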
auto empty_block = alloc(std::vector({})); parent->replace_stmt(loops.front(), empty_block); for (size_t i = 1; i < loops.size(); ++i) { auto block = to(loops[i]->get_parent()); TORCH_INTERNAL_ASSERT( block, buildErrorMessage( "Expected parent stmt to be a non-null Block in reorder transformation the fuser.")); block->remove_stmt(loops[i]); } // Set the new bodies after reorder for all the loops. for (size_t i = 0; i < result.size() - 1; ++i) { result[i]->set_body(result[i + 1]); } result.back()->set_body(innermost_body); parent->replace_stmt(empty_block, result.front()); return result; } ForPtr LoopNest::getLoopAt(ForPtr root, const std::vector& indices) const { if (indices.empty()) { return root; } if (root == nullptr) { throw malformed_input("root loop is null"); } ForPtr curr = std::move(root); for (auto i : indices) { if (i < 0 || curr->body()->nstmts() <= static_cast(i)) { return nullptr; } std::list::iterator stmtp = curr->body()->begin(); std::advance(stmtp, i); curr = to(*stmtp); if (curr == nullptr) { return nullptr; } } return curr; } ForPtr LoopNest::tile( const ForPtr& x, const ForPtr& y, int x_factor, int y_factor) { auto parent = to(x->get_parent()); if (parent == nullptr) { throw malformed_input("parent of the loops must be a Block"); } if (!areLoopsPerfectlyNested({x, y})) { throw malformed_input("two loops must be perfectly nested"); } // Split x, y axes by x_factor and y_factor ForPtr yi, ytail; splitWithTail(y, y_factor, &yi, &ytail); ForPtr xi, xtail; splitWithTail(x, x_factor, &xi, &xtail); // Distribute xi over yo and ytail so we can manipulate the loop order of {xo, // xi, yo, yi} auto loops = distributeLoop(xi); // For {xi, yo, yi}, reorder the axes to be yo, xi, yi xi = loops.front(); ForPtr yo = to(xi->body()->stmts().front()); CHECK(yo); reorder({xi, yo}, {1, 0}); // For {xi, ytail}, reorder the axes to be ytail, xi if (loops.size() == 2) { xi = loops.back(); ytail = to(xi->body()->stmts().front()); CHECK(ytail); reorder({xi, ytail}, {1, 0}); } return xtail; } bool LoopNest::areLoopsPerfectlyNested(const std::vector& loops) { if (loops.size() < 2) { return true; } for (size_t i = 0; i < loops.size() - 1; ++i) { auto loop_body = loops[i]->body(); if (loop_body->nstmts() != 1 || loop_body->front() != loops[i + 1]) { return false; } } return true; } void LoopNest::fullUnroll(const ForPtr& f, StmtPtr* unrolled) { BlockPtr p = to(f->get_parent()); if (!f) { throw malformed_input("unroll attempted on null loop"); } else if (!p) { throw malformed_input("unroll attempted on loop with no parent"); } auto start_expr = IRSimplifier::simplify(f->start()); auto stop_expr = IRSimplifier::simplify(f->stop()); if (!start_expr->isConstant()) { throw std::runtime_error("Can't unroll due to non-constant loop start!"); } if (!stop_expr->isConstant()) { throw std::runtime_error("Can't unroll due to non-constant loop stop!"); } std::vector unrolled_stmts; int start_val = immediateAs(start_expr); int stop_val = immediateAs(stop_expr); for (int current = start_val; current < stop_val; ++current) { for (const auto& stmt : f->body()->stmts()) { unrolled_stmts.push_back(SubstituteInClone( stmt, {{f->var(), getImmediateByType(f->var()->dtype(), current)}})); } } *unrolled = alloc(unrolled_stmts); *unrolled = IRSimplifier::simplify(*unrolled); p->replace_stmt(f, *unrolled); } void LoopNest::fullUnroll(const ForPtr& f) { StmtPtr unrolled; fullUnroll(f, &unrolled); } void LoopNest::unroll(const ForPtr& f, int factor, ForPtr* tail) { if (factor < 2) { return; } ForPtr inner; splitWithTail(f, 
factor, &inner, tail); fullUnroll(inner); } void LoopNest::unroll(const ForPtr& f, int factor) { ForPtr tail; unroll(f, factor, &tail); } bool LoopNest::isNormalized(const ForPtr& f) { if (f->start()->isConstant()) { return immediateAs(f->start()) == 0; } return false; } bool LoopNest::normalize(const ForPtr& f) { if (!f) { throw malformed_input("normalize attempted on null loop"); } if (isNormalized(f)) { // No need to normalize anymore here. return false; } auto for_body_normalized = Substitute( f->body(), {{f->var(), (VarHandle(f->var()) + ExprHandle(f->start())).node()}}); f->set_body(IRSimplifier::simplify(for_body_normalized)); f->set_stop(IRSimplifier::simplify(alloc(f->stop(), f->start()))); f->set_start(immLike(f->stop(), 0)); return true; } // This function expects that there are 'num' loops perfectly nested within // and including 'f'. std::vector LoopNest::getLoopStmtsInLoopNest( const ForPtr& f, size_t num) { std::vector loops(num); ForPtr curr_for = f; loops[0] = curr_for; for (size_t i = 1; i < num; ++i) { TORCH_INTERNAL_ASSERT( curr_for->body()->nstmts() == 1, buildErrorMessage("Expected a single stmt in the loop body.")); curr_for = to(curr_for->body()->front()); TORCH_INTERNAL_ASSERT( curr_for, buildErrorMessage("Expected the only child stmt to be a For loop.")); loops[i] = curr_for; } return loops; } bool LoopNest::flatten(const std::vector& loops, ForPtr* flattened) { if (loops.empty()) { throw malformed_input("flatten attempted on empty set of loops"); } BlockPtr p = to(loops[0]->get_parent()); if (!p) { throw malformed_input("flatten attempted on loops with no parent"); } if (loops.size() == 1) { // This loop nest is already flattened. *flattened = loops[0]; return false; } // Check if all the loops correspond to a perfect loopnest: // * every loop except the inner-most should have only one stmt, the For. // Do not flatten, otherwise. // This check also ensures we do not flatten reduction loops. for (size_t i = 0; i < loops.size() - 1; ++i) { if ((loops[i]->body()->nstmts() != 1) || (loops[i]->body()->front() != loops[i + 1])) { return false; } } // Normalize the loops before flattening. // We need to normalize them from inner-most to outer because once the outer // loop is normalized, the given pointers to inner loops point to old code. // For the same reason, we can't store the normalized inner loops until after // the outer-most loop is normalized. for (size_t i = 0; i < loops.size(); ++i) { size_t idx = loops.size() - i - 1; LoopNest::normalize(loops[idx]); } // 'normalized' points to the outer-most loop in the normalized loopnest. // Collect all the normalized loops. auto normalized_loops = getLoopStmtsInLoopNest(loops.front(), loops.size()); auto flat_var = alloc( normalized_loops[0]->var()->name_hint() + "_flat", normalized_loops[0]->var()->dtype()); VarMapping var_mapping; ExprPtr stop = immLike(flat_var, 1); for (size_t i = 0; i < normalized_loops.size(); ++i) { size_t idx = normalized_loops.size() - i - 1; auto curr_loop = normalized_loops[idx]; ExprPtr div = alloc
(flat_var, stop); ExprPtr sub_expr = idx == 0 ? div : alloc(div, curr_loop->stop()); var_mapping.emplace_back(curr_loop->var(), sub_expr); stop = alloc(curr_loop->stop(), stop); } auto flattened_body = Substitute(normalized_loops.back()->removeBody(), var_mapping); normalized_loops.front()->set_var(flat_var); normalized_loops.front()->set_start(immLike(stop, 0)); normalized_loops.front()->set_stop(stop); normalized_loops.front()->set_body(flattened_body); *flattened = normalized_loops.front(); return true; } bool LoopNest::flatten(const std::vector& loops) { ForPtr flattened; return flatten(loops, &flattened); } void LoopNest::compressBuffer(const BufPtr& buf, const StmtPtr& stmt) { // Loop iterations in NNC IR do not follow sequential semantics by default. // In other words, the iterations of the loops could be executed in any // random order without affecting correctness. This constraint in turn // implies that there can’t be any *inter-iteration* dependences // (or *loop-carried* dependences) in NNC loops. So, any NNC IR with such // dependences is considered invalid. // // Given the constraint above, for any pair of accesses to a buffer (where // at least one of the access is a write), the accesses must be // loop-independent on the innermost loop containing the accesses as well as // all the loops above it. So, any dimension that uses only those loop // variables to access the given buffer could be optimized away. // // Algorithm: // * Find all the accesses to the given buf. (A) // * Find the parent common to all accesses in A. (P) // * Collect all the loops above P. (L) // * Collect all the loop variables corresponding to L. (LV) // * For every access a in A: // * For the index I in every dimension of a: // * If the variables in I are all in LV, mark this dimension // for deletion. // * For every dimension that is marked for deletion in ALL accesses in A: // * Update the buffer to set the size of that dimension to 1. // * Update all accesses in A to set the index in that dimension to 0. auto writes = WritesToBuf::find(stmt, buf); auto reads = StmtsReadingBuf::find(stmt, buf); // Find the parent common to all the buffer accesses. BlockPtr parent = to(writes.front()->get_parent()); TORCH_INTERNAL_ASSERT( parent, buildErrorMessage( "Expected parent stmt to be a non-null block in compressBuffer in the fuser.")); for (const auto& w : writes) { parent = Block::getSharedParent(parent, w); } for (const auto& r : reads) { parent = Block::getSharedParent(parent, r); } // Collect all the loops that are above the common parent. auto loops = LoopNest::getEnclosingLoopNest(parent); std::unordered_set loop_vars; for (const auto& l : loops) { loop_vars.insert(l->var()); } // TODO: Need to handle other Stmts / Exprs that read / write buffers. auto stores = NodeFinder::find(stmt); auto loads = NodeFinder::find(stmt); // Vector to indicate which dimensions could be compressed away. std::vector dims(buf->dims().size(), true); auto check_indices = [&](const std::vector& indices) { TORCH_INTERNAL_ASSERT( indices.size() == dims.size(), buildErrorMessage( "Expected ranks to match in compressBuffer in the fuser.")); for (size_t i = 0; i < indices.size(); ++i) { auto index_vars = NodeFinder::find(indices[i]); for (const auto& iv : index_vars) { if (loop_vars.count(iv) == 0) { // A variable in this index is not in loop_vars. // This implies that this dimension cannot be optimized away. 
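// For instance (illustrative): if every access to a buf B[100, 200] has the
// form B[i, j], where loop `i` encloses the common parent of all the accesses
// but `j` is local to them, then dimension 1 must be kept, while dimension 0 -
// indexed only by `i` - is compressed below: B becomes B[1, 200] and the
// accesses become B[0, j].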
dims[i] = false; break; } } } }; for (const auto& s : stores) { if (s->buf() == buf) { check_indices(s->indices()); } } for (const auto& l : loads) { if (l->buf() == buf) { check_indices(l->indices()); } } bool any_dim_to_compress = false; for (auto d : dims) { any_dim_to_compress |= d; } if (!any_dim_to_compress) { return; } // Compress buffer by removing the marked dims. std::vector new_dims(buf->dims()); for (size_t i = 0; i < dims.size(); ++i) { if (dims[i]) { new_dims[i] = immLike(buf->dims()[i], 1); } } buf->set_dims(new_dims); // Modify all access to reflect the removed dims. auto get_new_indices = [&](const std::vector& indices) { TORCH_INTERNAL_ASSERT( indices.size() == dims.size(), buildErrorMessage( "Expected ranks to match in compressBuffer in the fuser.")); std::vector new_indices(indices); for (size_t i = 0; i < dims.size(); ++i) { if (dims[i]) { new_indices[i] = immLike(indices[i], 0); } } return new_indices; }; for (const auto& s : stores) { if (s->buf() == buf) { s->set_indices(get_new_indices(s->indices())); } } for (const auto& l : loads) { if (l->buf() == buf) { l->set_indices(get_new_indices(l->indices())); } } } void LoopNest::compressAllBuffers(const StmtPtr& stmt) { for (const auto& buf : BufFinder::find(stmt)) { compressBuffer(buf, stmt); } } std::vector LoopNest::getLoopStmtsFor(const Tensor& t) const { StmtPtr cur_stmt = getLoopBodyFor(t); return getLoopStmtsFor(cur_stmt); } std::vector LoopNest::getLoopStmtsFor(const BufPtr& buf) const { StmtPtr cur_stmt = getLoopBodyFor(buf); return getLoopStmtsFor(cur_stmt); } std::vector LoopNest::getLoopStmtsFor(StmtPtr s) const { std::vector result; while (s) { if (auto loop = to(s)) { result.push_back(loop); } s = s->get_parent(); } std::reverse(result.begin(), result.end()); return result; } StmtPtr LoopNest::getLoopBodyFor(const Tensor& t) const { return getLoopBodyFor(t.buf()); } StmtPtr LoopNest::getLoopBodyFor(BufPtr buf) const { auto writes = WritesToBuf::find(root_stmt_, std::move(buf)); // special case for reduction Tensors, ignore the initializer if it's the only // op: if (writes.size() == 2) { if (StorePtr s = to(writes.back())) { if (ReduceOpPtr r = to(s->value())) { return (StmtPtr)s; } } } StmtPtr res = nullptr; for (const auto& s : writes) { if (!res) { res = s; continue; } res = Block::getSharedParent(res, s); } return (StmtPtr)res; } ForPtr LoopNest::getParentLoop(const StmtPtr& st) { if (st == nullptr) { return nullptr; } auto par = st->get_parent(); if (auto f = to(par)) { return f; } return getParentLoop(par); } std::vector LoopNest::getEnclosingLoopNest(const StmtPtr& st) { std::vector loops; auto f = getParentLoop(st); while (f) { loops.push_back(f); f = getParentLoop(f); } std::reverse(loops.begin(), loops.end()); return loops; } std::vector LoopNest::getAllWritesToBuf(BufPtr buf) const { return WritesToBuf::find(root_stmt_, std::move(buf)); } std::vector LoopNest::getAllInnermostLoopsWritingToBuf( BufPtr buf) const { auto writes = getAllWritesToBuf(std::move(buf)); std::vector innermost_loops; innermost_loops.reserve(writes.size()); for (const auto& w : writes) { innermost_loops.push_back(LoopNest::getParentLoop(w)); } return innermost_loops; } std::vector> LoopNest::getAllLoopNestsWritingToBuf( BufPtr buf) const { auto writes = getAllWritesToBuf(std::move(buf)); std::vector> loopnests; loopnests.reserve(writes.size()); for (const auto& w : writes) { loopnests.emplace_back(LoopNest::getEnclosingLoopNest(w)); } return loopnests; } StmtPtr LoopNest::simplify() { root_stmt_ = 
IRSimplifier::simplify(root_stmt_); return root_stmt_; } StmtPtr FlattenIndexes(const StmtPtr& s) { IndexFlattener idx_flattener; return idx_flattener.flatten(s); } // Auxiliary class for rewriting we're doing in `compute_at`. See // LoopNest::computeAt for more details. class LoopComputeAtRewriter : public IRMutator { public: LoopComputeAtRewriter( BufPtr buf, BufPtr new_buf, std::vector offsets) : buf_(std::move(buf)), new_buf_(std::move(new_buf)), offsets_(std::move(offsets)) {} private: BufPtr buf_; BufPtr new_buf_; std::vector offsets_; ExprPtr mutate(const LoadPtr& v) override { if (v->buf() != buf_) { return v; } std::vector new_indices(v->indices().size()); for (const auto i : c10::irange(v->indices().size())) { new_indices[i] = IRSimplifier::simplify(alloc(v->indices()[i], offsets_[i])); } return alloc(v->dtype(), new_buf_, new_indices); } }; static StorePtr getStoreStmtOfProducer(const StmtPtr& s) { if (StorePtr st = to(s)) { return st; } if (BlockPtr b = to(s)) { for (const StmtPtr& ss : *b) { if (StorePtr st = to(ss)) { return st; } } } return nullptr; } static std::vector getOuterLoopIndexes(StmtPtr s) { std::vector res; StmtPtr cur = std::move(s); while (cur) { if (auto l = to(cur)) { res.push_back(l->var()); } cur = cur->get_parent(); } return res; } class CacheReplacer : public IRMutator { public: CacheReplacer(BufPtr buffer, BufPtr cache, std::vector& offsets) : buf_(std::move(buffer)), cache_(std::move(cache)), offsets_(offsets) {} private: ExprPtr mutate(const LoadPtr& v) override { BufPtr buf = v->buf(); if (buf != buf_) { return IRMutator::mutate(v); } // Map indices to call-parameters. std::vector newIndices; TORCH_INTERNAL_ASSERT( offsets_.size() == v->indices().size(), buildErrorMessage( "Expected ranks to match in CacheReplacer in the fuser.")); for (size_t i = 0; i < v->indices().size(); ++i) { ExprPtr index = v->indices()[i]->accept_mutator(this); ExprPtr offset = offsets_[i]; ExprPtr sub = IRSimplifier::simplify(alloc(index, offset)); newIndices.push_back(sub); } v->set_buf(cache_); v->set_indices(newIndices); return v; } StmtPtr mutate(const StorePtr& v) override { BufPtr buf = v->buf(); if (buf != buf_) { return IRMutator::mutate(v); } ExprPtr newValue = v->value()->accept_mutator(this); // Map indices to call-parameters. std::vector newIndices; TORCH_INTERNAL_ASSERT( offsets_.size() == v->indices().size(), buildErrorMessage( "Expected ranks to match in CacheReplacer in the fuser.")); for (size_t i = 0; i < v->indices().size(); ++i) { ExprPtr index = v->indices()[i]->accept_mutator(this); ExprPtr offset = offsets_[i]; ExprPtr sub = IRSimplifier::simplify(alloc(index, offset)); newIndices.push_back(sub); } v->set_buf(cache_); v->set_indices(newIndices); v->set_value(newValue); return v; } BufPtr buf_; BufPtr cache_; std::vector& offsets_; }; LoopNest::AccessResult LoopNest::cacheAccesses( const BufPtr& producer, const std::string& name, const StmtPtr& consumer) { ReduceOpPtr reduceOp{nullptr}; auto stores = NodeFinder::find(consumer); for (const auto& store : stores) { if (auto ro = to(store->value())) { if (store->buf() != producer) { continue; } if (reduceOp) { throw std::runtime_error( "can only cache accesses used by at most a single reduceOp"); return {nullptr, nullptr}; } reduceOp = ro; } } // Check bounds but don't care about AccessKind. 
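// E.g. (illustrative) if the consumer reads producer[i + 2] for i in 0..10,
// bounds inference reports start = 2 and stop = 11 for that dimension, and the
// cache allocated below gets extent (11 - 2) + 1 = 10.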
auto consumer_bounds_info = inferBounds(consumer, false); auto bounds_it = consumer_bounds_info.find(producer); if (bounds_it == consumer_bounds_info.end()) { throw std::runtime_error("consumer does not use the Tensor produced"); return {nullptr, nullptr}; } TORCH_INTERNAL_ASSERT( bounds_it->second.size() == 1, buildErrorMessage( "Unexpected number of bound info entries in cacheAccesses in the fuser.")); TensorAccessBoundsInfo& info = bounds_it->second[0]; bool hasReads = info.kind == kLoad || info.kind == kMutate; bool hasWrites = info.kind == kStore || info.kind == kMutate; std::vector var_names = {"i", "j", "k", "l", "m", "n", "o", "p"}; std::vector tmp_dims; std::vector new_loop_vars; std::vector new_loop_vars_expr; // Determine the size of the cache, and create a loop var for each dimension. for (size_t i = 0; i < info.start.size(); ++i) { ExprPtr dim = IRSimplifier::simplify(alloc( alloc(info.stop[i], info.start[i]), immLike(info.stop[i], 1))); tmp_dims.push_back(dim); new_loop_vars.push_back( alloc(var_names[i % var_names.size()], info.stop[i]->dtype())); new_loop_vars_expr.push_back(new_loop_vars[i]); } // Create the var. BufPtr tmp_buf = alloc(alloc(name, kHandle), tmp_dims, producer->dtype()); // determine the offsets for calls into the cache based off the loop start of // each axis. std::vector tmp_params; for (size_t i = 0; i < new_loop_vars.size(); ++i) { tmp_params.push_back(alloc(new_loop_vars[i], info.start[i])); } // Replace accesses to the producer in the consumer with the cache. CacheReplacer replacer(producer, tmp_buf, info.start); consumer->accept_mutator(&replacer); // replace the old consumer with the replaced consumer. BlockPtr consumer_block = to(consumer); BlockPtr parent_block = to(consumer->get_parent()); // if the consumer is a block, we should mutate it in place. bool is_block = consumer_block != nullptr; // If there's a reduction and we are operating on the reduce axis, we need to // initialize the cache with 0s. Also, we can't just write the result straight // back to the original buffer, since after parallelism the writes will race. // Instead we need to create a new ReduceOp. bool on_reduce_axis = false; if (reduceOp) { std::set reduce_args( reduceOp->reduce_args().begin(), reduceOp->reduce_args().end()); std::set enclosing_vars; for (const auto& enclosing_for_stmt : NodeFinder::find(consumer)) { enclosing_vars.insert(enclosing_for_stmt->var()); } for (const auto& reduce_arg : reduce_args) { if (enclosing_vars.find(reduce_arg) == enclosing_vars.end()) { on_reduce_axis = true; } } } if (reduceOp && on_reduce_axis) { // reduceOp means we had both loads and stores. // Init cache to 0. 
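// Roughly (illustrative sketch, for a 1-D cache `tmp` over a sum reduction):
//   for i in 0..tmp_dim0: tmp[i] = 0                        <- prepended below
//   <consumer, with producer accesses redirected to tmp>
//   for i in 0..tmp_dim0:                                    <- appended below
//     producer[start0 + i] = ReduceOp(producer[start0 + i] + tmp[i], ...)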
    StmtPtr tmp_init = alloc<Store>(
        tmp_buf, new_loop_vars_expr, getImmediateByType(tmp_buf->dtype(), 0));

    for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
      tmp_init = alloc<For>(
          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_init);
    }

    if (is_block) {
      consumer_block->prepend_stmt(tmp_init);
    } else {
      parent_block->insert_stmt_before(tmp_init, consumer);
    }

    // Reduce back to the original buffer:
    StmtPtr tmp_store = alloc<Store>(
        producer,
        tmp_params,
        reduceOp->reducer()(
            producer,
            alloc<Load>(tmp_buf, new_loop_vars_expr),
            tmp_params,
            {}));

    for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
      tmp_store = alloc<For>(
          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
    }

    if (is_block) {
      consumer_block->append_stmt(tmp_store);
    } else {
      parent_block->insert_stmt_after(tmp_store, consumer);
    }

    return std::make_pair(tmp_buf, consumer);
  }

  if (hasReads) {
    // Fill the cache with values from the consumer.
    StmtPtr tmp_store = alloc<Store>(
        tmp_buf, new_loop_vars_expr, alloc<Load>(producer, tmp_params));

    for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
      tmp_store = alloc<For>(
          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
    }

    if (is_block) {
      consumer_block->prepend_stmt(tmp_store);
    } else {
      parent_block->insert_stmt_before(tmp_store, consumer);
    }
  }

  if (hasWrites) {
    // Sync the cache back to the producer buf.
    StmtPtr tmp_store = alloc<Store>(
        producer, tmp_params, alloc<Load>(tmp_buf, new_loop_vars_expr));

    for (int64_t i = static_cast<int64_t>(new_loop_vars.size()) - 1; i >= 0;
         --i) {
      tmp_store = alloc<For>(
          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
    }

    if (is_block) {
      consumer_block->append_stmt(tmp_store);
    } else {
      parent_block->insert_stmt_after(tmp_store, consumer);
    }
  }

  return std::make_pair(tmp_buf, consumer);
}

/*
 * WHAT COMPUTE_AT DOES
 * ====================
 *
 * Suppose we have two loops:
 *
 *   for i in 0..100:
 *     for j in 0..200:
 *       A[i,j] = sin(i*j)
 *   for i in 0..100:
 *     for j in 0..199:
 *       B[i,j] = A[i,j] + A[i, j+1]
 *
 * If we compute these loops as is, we would have to allocate two buffers:
 * 100x200 for A and 100x199 for B. To decrease the memory usage one can use
 * the compute_inline primitive, which would result in the following:
 *
 *   for i in 0..100:
 *     for j in 0..199:
 *       B[i,j] = sin(i*j) + sin(i*(j+1))
 *
 * We now need only one buffer - 100x199 for B. However, we're now doing some
 * redundant computations: we're calling `sin` twice as often as in the first
 * version.
 *
 * Ultimately, we need to choose at what point we prefer to compute the values
 * of A[i,j] - we can do it at the very beginning for the entire buffer A (the
 * first option) or compute them on the fly when we compute B (the second
 * option). There are also options in between those two: we can compute just
 * the part of A that is required for computing a part of B, e.g. a single row
 * of B. The code would then look like:
 *
 *   for i in 0..100:
 *     for j in 0..200:
 *       A[j] = sin(i*j)
 *     for j in 0..199:
 *       B[i,j] = A[j] + A[j+1]
 *
 * In this case we're only using 1x200 for A, and we're avoiding the redundant
 * computations.
 *
 * The purpose of `compute_at` is to achieve exactly this transformation.
 *
 * compute_at requires specifying What to compute and Where to compute it: in
 * our example we would call compute_at(What=`A[i,j] = sin(i*j)`, Where=`for i
 * in 0..100`).
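 *
 * A minimal usage sketch against the C++ API in this file (illustrative only;
 * `l` is a LoopNest containing the Tensors for A and B above, and the variable
 * names are hypothetical):
 *
 *   StmtPtr what = l.getLoopBodyFor(A);      // the store `A[i,j] = sin(i*j)`
 *   ForPtr where = l.getLoopStmtsFor(B)[0];  // the outer `i` loop of B
 *   l.computeAt(what, where);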
 *
 * More info about compute_at can be found in Halide's tutorials:
 * https://halide-lang.org/tutorials/tutorial_lesson_08_scheduling_2.html
 *
 * HOW COMPUTE_AT WORKS
 * ====================
 *
 * The most important part of compute_at is bounds inference: we need to figure
 * out what part of the used tensors we need to compute when we move the
 * computation to a new scope. In the example above, we need bounds inference
 * to tell us that in order to compute A at each iteration of the outer loop,
 * we need to compute A within the indices [i:i+1, 0:200].
 *
 * This info allows us to conclude that we need a temp buffer of size 1x200.
 *
 * Once this is known we need to insert statements for allocating and freeing
 * the temporary buffer and copy the original computation to fill the temp
 * buffer with proper values. When we copy the computation we also must rewrite
 * the indices used in it: the old indices refer to the old loop and are not
 * valid in the new loop.
 *
 * To make the logic easier to follow, let's examine an example. Suppose we
 * start from the following loop nest:
 *
 *   for py in 0..100:
 *     for px in 0..100:
 *       producer[py,px] = py*px
 *   for cy in 0..100:
 *     for cx in 0..100:
 *       consumer[cy,cx] = producer[cy,cx]
 *
 * And then we're running `compute_at(producer, cy)`.
 *
 * What we would like to get is the following loop nest:
 *
 *   for py in 0..100:
 *     for px in 0..100:
 *       producer[py,px] = py*px
 *   for cy in 0..100:
 *     Allocate(temp, {1, 100})
 *     for ty in 0..1:
 *       for tx in 0..100:
 *         temp[ty,tx] = (ty+cy)*(tx+0)
 *     for cx in 0..100:
 *       consumer[cy,cx] = temp[0,cx]
 *     Free(temp)
 *
 * NB: this loop nest can and should be simplified (e.g. the producer loop can
 * be removed since its result is no longer used), but this clean-up
 * optimization is performed separately (currently, not performed at all).
 *
 * If we examine the final loop nest, we can identify the following steps that
 * need to be performed:
 *   - Bounds inference needs to tell us that we need a 1x100 buffer for temp.
 *   - Allocate and Free statements for this buffer need to be inserted into
 *     the loop.
 *   - A new loop nest should be inserted into the loop CY for computing
 *     `temp`, and it should replicate the loop nest of the producer (the PY,
 *     PX loops). The indices in the loop body need to be offset by (cy, 0) -
 *     the offsets come from bounds inference too.
 *   - The computation of `consumer` needs to be rewritten so that it uses
 *     `temp` instead of `producer`. The indices in the corresponding accesses
 *     also need to be offset.
 */
void LoopNest::computeAt(const StmtPtr& s, const ForPtr& f) {
  StorePtr st = getStoreStmtOfProducer(s);
  if (!st) {
    return;
  }

  // Infer bounds info for all accesses that we make in the loop
  auto loop_bounds_info = inferBounds(f->body());

  // bounds_it holds bounds info for the store we're trying to move to
  // the loop. If its result isn't accessed in the loop at all - do nothing and
  // exit early.
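// (For the producer/consumer example in the comment above, bounds inference
// over the `cy` body reports that rows [cy, cy] and columns [0, 99] of
// `producer` are read, which is what yields the 1x100 `temp` buffer.)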
auto bounds_it = loop_bounds_info.find(st->buf()); if (bounds_it == loop_bounds_info.end()) { return; } // Compute dimensions of the temp buffer we would need to allocate std::vector dims = getBoundExtents(bounds_it->second); // TODO: Use name-hint of the producer instead of "temp" BufPtr temp_buf = alloc("temp", dims, st->value()->dtype()); // Generate index variables for 'temp' std::vector temp_indices(dims.size()); for (const auto i : c10::irange(dims.size())) { // TODO: Use name-hint of the producer indices instead of 'idx' temp_indices[i] = alloc(std::string("idx") + std::to_string(i), dims[i]->dtype()); } // Prepare substitute rules for constructing the temp statement from the prod // statement // TODO: Instead of going up the loop nest we should go through the indices in // the original tensor expression. The loops in the nest might've been // modified (e.g. split or merged) so that the loop indices no longer // correspond to the indices of the original expression and even their number // might be different. In that case, the loop below would crash. std::vector prod_indices = getOuterLoopIndexes(s); std::vector> rewrite_indices_map; std::vector offsets; for (const TensorAccessBoundsInfo& p : bounds_it->second) { for (const auto i : c10::irange(p.start.size())) { if (offsets.size() <= i) { offsets.push_back(p.start[i]); } else { offsets[i] = IRSimplifier::simplify(alloc(offsets[i], p.start[i], true)); } } } for (const auto i : c10::irange(prod_indices.size())) { rewrite_indices_map.emplace_back( prod_indices[i], alloc(temp_indices[i], offsets[i])); } // Construct the temp statement StmtPtr bd = alloc( temp_buf, temp_indices, SubstituteInClone(st->value(), rewrite_indices_map)); // Construct the loop nest for the temp computation for (const auto i : c10::irange(dims.size())) { // We're creating loops from innermost to outermost, so we need to access // dimensions in reversed order. 
size_t dim_idx = dims.size() - 1 - i; bd = alloc( to(temp_indices[dim_idx]), immLike(dims[dim_idx], 0), dims[dim_idx], bd); } // Add constructed stmts to the consumer loop f->body()->prepend_stmt(bd); // Rewrite accesses to producer in consumer with accesses to temp LoopComputeAtRewriter lr(st->buf(), temp_buf, offsets); StmtPtr new_f = f->accept_mutator(&lr); if (f != new_f) { BlockPtr bb = to(f->get_parent()); bb->replace_stmt(f, new_f); } } class RfactorStoreRewriter : public IRMutator { public: RfactorStoreRewriter( BufPtr old_buf, const std::vector& old_indices, BufPtr new_buf, VarPtr reduction_var) : old_buf_(std::move(old_buf)), old_indices_(old_indices), new_buf_(std::move(new_buf)), reduction_var_(std::move(reduction_var)), new_indices_(old_indices) { new_indices_.push_back(reduction_var_); } ExprPtr mutate(const LoadPtr& v) override { if (v->buf() != old_buf_) { return IRMutator::mutate(v); } TORCH_INTERNAL_ASSERT( old_indices_.size() == v->indices().size(), buildErrorMessage( "Expected ranks to match in RfactorStoreRewriter in the fuser.")); bool equal_indices = true; for (size_t i = 0; i < v->indices().size(); ++i) { if (!exprEquals(v->indices()[i], old_indices_[i])) { equal_indices = false; break; } } if (!equal_indices) { return IRMutator::mutate(v); } return alloc(new_buf_, new_indices_); } ExprPtr mutate(const ReduceOpPtr& v) override { ExprPtr body_new = v->body()->accept_mutator(this); std::vector new_reduce_args; for (const auto& r : v->reduce_args()) { if (r != reduction_var_) { new_reduce_args.push_back(r); } } return alloc(body_new, new_reduce_args, v->reducer()); } StmtPtr mutate(const StorePtr& v) override { if (v->buf() != old_buf_) { return IRMutator::mutate(v); } TORCH_INTERNAL_ASSERT( old_indices_.size() == v->indices().size(), buildErrorMessage( "Expected ranks to match in RfactorStoreRewriter in the fuser.")); bool equal_indices = true; for (size_t i = 0; i < v->indices().size(); ++i) { if (!exprEquals(v->indices()[i], old_indices_[i])) { equal_indices = false; break; } } if (!equal_indices) { return IRMutator::mutate(v); } ExprPtr new_value = v->value()->accept_mutator(this); return alloc(new_buf_, new_indices_, new_value); } private: BufPtr old_buf_; const std::vector& old_indices_; BufPtr new_buf_; VarPtr reduction_var_; std::vector new_indices_; }; bool LoopNest::rfactor(const StmtPtr& st, const ForPtr& target_for) { BufPtr tmp_buf = nullptr; return rfactor(st, target_for, &tmp_buf); } bool LoopNest::rfactor( const StmtPtr& st, const ForPtr& outer_reduction_for, BufPtr* rfac_buf_ptr) { StorePtr reduction_store = to(st); ReduceOpPtr reduce_op = to(reduction_store->value()); if (!reduce_op) { // Not a reduction store return false; } auto orig_buf = reduction_store->buf(); auto orig_buf_indices = reduction_store->indices(); VarPtr reduction_var = outer_reduction_for->var(); std::set reduce_args = { reduce_op->reduce_args().begin(), reduce_op->reduce_args().end()}; if (reduce_args.size() < 2) { // Not enough reduction axis to do rfactor return false; } // Verify that outer_reduction_for is a perfect loop nest with all loops being // reductions StmtPtr cur = outer_reduction_for; while (ForPtr cur_for = to(cur)) { if (!reduce_args.count(cur_for->var())) { // output axis inside outer_reduction_for are not allowed return false; } reduce_args.erase(cur_for->var()); BlockPtr b = cur_for->body(); if (b->nstmts() != 1) { return false; } cur = b->stmts().front(); } if (cur != st) { // The reduction store is not a single stmt in the innermost loop - bail in // that case 
return false; } if (!reduce_args.empty()) { // This is not the outermost reduction axis return false; } // assert: reduce_axis match loop vars from outer_reduction_for and inside // assert: no other stmts in outer_reduction_for or its child loops std::vector rfac_dims = orig_buf->dims(); ExprPtr extra_dim = IRSimplifier::simplify( alloc(outer_reduction_for->stop(), outer_reduction_for->start())); rfac_dims.push_back(extra_dim); ExprPtr rfac_init = alloc(reduce_op->dtype(), reduce_op->reducer().initializer()); *rfac_buf_ptr = alloc( orig_buf->name_hint() + "_rfac", rfac_dims, reduce_op->dtype(), rfac_init); BufPtr rfac_buf = *rfac_buf_ptr; // Rewrite the original reduction store to use the temporary rfac buffer: // 1) X[*indexes] --> T[*indexes + {reduction_var}] // 2) reduce_axis -= {reduction_var} RfactorStoreRewriter rfac_rewriter( orig_buf, orig_buf_indices, rfac_buf, reduction_var); to(st->get_parent()) ->replace_stmt(st, st->accept_mutator(&rfac_rewriter)); // Insert a store for the final reduction over the temp buffer into the // original buffer: // X[*indexes] = ReduceOp(X[*indexes] + T[*indexes + {reduction_var}], // reduce_axis={reduction_var}) BlockPtr b = outer_reduction_for->body(); TORCH_INTERNAL_ASSERT( b->nstmts() == 1, buildErrorMessage( "Expected to have a single stmt in the block in rfactor transformation in the fuser.")); StmtPtr first_reduction_loop = b->stmts().front(); auto rfac_buf_indices = orig_buf_indices; rfac_buf_indices.emplace_back(reduction_var); ExprPtr final_reduce_load = alloc(rfac_buf, rfac_buf_indices); outer_reduction_for->body()->insert_stmt_after( alloc( orig_buf, orig_buf_indices, reduce_op->reducer()( orig_buf, final_reduce_load, orig_buf_indices, {reduction_var})), first_reduction_loop); // Insert an initialization store for the temp buffer: // T[a,b,c] = init outer_reduction_for->body()->insert_stmt_before( alloc(rfac_buf, rfac_buf_indices, rfac_init), first_reduction_loop); return true; } } // namespace torch::jit::tensorexpr
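
// rfactor sketch (illustrative only, not part of the implementation above):
// given a full reduction
//   for l: for m: for n:
//     X[0] = ReduceOp(X[0] + A[l, m, n], reduce_axis={l, m, n})
// calling rfactor(store, l_loop, &rfac_buf) rewrites it, roughly, into
//   for l:
//     X_rfac[l] = init
//     for m: for n:
//       X_rfac[l] = ReduceOp(X_rfac[l] + A[l, m, n], reduce_axis={m, n})
//     X[0] = ReduceOp(X[0] + X_rfac[l], reduce_axis={l})
// so the partial sums over {m, n} can later be parallelized across `l`.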