#include <torch/csrc/jit/passes/concat_opt.h>

#include <algorithm>
#include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <c10/util/ssize.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

namespace torch::jit {
namespace {

void removeCatNodeFromGraph(Node* n) {
  TORCH_INTERNAL_ASSERT(n->kind() == aten::cat);
  auto inp_list = n->input(0);
  GRAPH_UPDATE("Deleting\n", *n);
  n->destroy();

  if (!inp_list->hasUses()) {
    GRAPH_UPDATE("Deleting\n", *inp_list->node());
    inp_list->node()->destroy();
  }
}

bool equal(at::ArrayRef<Value*> list1, at::ArrayRef<Value*> list2) {
  return list1.size() == list2.size() &&
      std::equal(list1.begin(), list1.end(), list2.begin());
}
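// Eliminates redundant common inputs among `prim::VarConcat` ops. If an
// earlier VarConcat already concatenated a prefix (or suffix) of the current
// op's tensor inputs along the same dimension, the current op is rewritten to
// consume that earlier result instead of re-concatenating the shared tensors.
// See the worked IR examples in handleCat() below.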
class ConcatCommonInputsEliminator {
 public:
  explicit ConcatCommonInputsEliminator(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  bool run() {
    handleBlock(graph_->block());
    return postprocess();
  }

 private:
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == prim::VarConcat) {
        handleCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  void handleCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for CSE opt: ", node);

    auto curr_all_inputs = node->inputs();
    auto curr_tensor_inputs =
        curr_all_inputs.slice(0, curr_all_inputs.size() - 1);
    auto curr_dim = curr_all_inputs.back();

    // Save the input list and the current cat node, so that this can be
    // used for subsequent cat nodes, unless there are writes to this cat
    // node. When there are writes to this cat node, its output does not
    // represent this concatenated list beyond the writes. Currently, we do
    // not perform such fine-grained analysis. So, if there are any writes to
    // the output, we do not use this cat node for optimization here.
    if (!getOrCreateAliasDb()->hasWriters(node->output())) {
      concated_outputs_.insert(node);
    }

    if (curr_tensor_inputs.size() <= 2) {
      // The case when concat has 2 input tensors could only be optimized if
      // there is another concat of the exact same 2 input tensors. That case
      // is expected to be handled by the CSE pass.
      return;
    }

    // Now, we check if the first N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %13 = prim::VarConcat(%0, %1, %2, <dim>) // first 2 inputs same
    //                                             // as %11
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %14 = prim::VarConcat(%11, %2, <dim>) // Replace first 2 inputs
    //                                          // with %11
    //    ...
    //        = %14 ... // Replace use of %13 with %14
    auto curr_tensor_inputs_prefix =
        curr_tensor_inputs.slice(0, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_prefix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated output if it does not
          // dominate the current concat node.
          continue;
        }

        std::vector<Value*> new_inputs = {
            prev->output(), curr_tensor_inputs.back(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Now, we check if the last N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %11 = prim::VarConcat(%1, %2, <dim>)
    //    ...
    //    %13 = prim::VarConcat(%0, %1, %2, <dim>) // last 2 inputs same
    //                                             // as %11
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %11 = prim::VarConcat(%1, %2, <dim>)
    //    ...
    //    %14 = prim::VarConcat(%0, %11, <dim>) // Replace last 2 inputs
    //                                          // with %11
    //    ...
    //        = %14 ... // Replace use of %13 with %14
    auto curr_tensor_inputs_suffix =
        curr_tensor_inputs.slice(1, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_suffix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated list if it does not
          // dominate the current list.
          continue;
        }

        std::vector<Value*> new_inputs = {
            curr_tensor_inputs.front(), prev->output(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Do we need to handle other cases where N-2 or fewer elements from
    // %inputs appear in any of the previous cat ops?
    // TODO.
  }

  bool postprocess() {
    // Replace the list nodes that have been marked.
    bool changed = false;
    for (auto it : concats_to_replace_) {
      auto curr_node = it.first;
      auto new_node = it.second;

      GRAPH_UPDATE("Inserting\n", *new_node, "before\n", *curr_node);
      new_node->insertBefore(curr_node);

      GRAPH_UPDATE("Replacing uses of\n", *curr_node, "with\n", *new_node);
      curr_node->output()->replaceAllUsesWith(new_node->output());

      GRAPH_UPDATE("Deleting\n", *curr_node);
      curr_node->destroy();
      changed = true;
    }
    return changed;
  }

  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  std::unordered_set<Node*> concated_outputs_;
  std::unordered_map<Node*, Node*> concats_to_replace_;
};

} // namespace

bool EliminateConcatCommonInputs(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before eliminating Concat common inputs", graph);
  bool changed = ConcatCommonInputsEliminator(graph).run();
  if (changed) {
    GRAPH_DUMP("After eliminating Concat common inputs", graph);
  }
  return changed;
}
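// Example usage (a minimal sketch; assumes `parseIR` from
// torch/csrc/jit/ir/irparser.h, which this file does not include):
//   auto graph = std::make_shared<Graph>();
//   parseIR(
//       R"IR(
//     graph(%0: Tensor, %1: Tensor, %2: Tensor):
//       %d : int = prim::Constant[value=0]()
//       %a : Tensor = prim::VarConcat(%0, %1, %d)
//       %b : Tensor = prim::VarConcat(%0, %1, %2, %d)
//       return (%a, %b))IR",
//       graph.get());
//   bool changed = EliminateConcatCommonInputs(graph);
//   // %b is rewritten to prim::VarConcat(%a, %2, %d).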
namespace {

class ConcatExpander {
 public:
  explicit ConcatExpander(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void run() {
    handleBlock(graph_->block());
    cleanupExpandedCatOps();
    GRAPH_DUMP("Before reusing copy buffers: ", graph_);
    reuseBuffersInCopies();
  }

 private:
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == aten::cat) {
        expandCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  // Expand cat node into multiple copy nodes.
  //
  // Example:
  //     %2 = aten::clamp(%0, ...)
  //     %3 = aten::clamp(%1, ...)
  //     %10 = prim::ListConstruct(%2, %3)
  //     %11 = aten::cat(%10, ...)
  //     ...
  //         = %11 ... // Use %11
  //
  // After expanding cat:
  //     %2 = aten::clamp(%0, ...)
  //     %3 = aten::clamp(%1, ...)
  //     %20 = aten::empty(...)          // cat output buffer
  //     %21 = aten::slice(%20, ...)     // slice for %2
  //     %22 = aten::copy_(%21, %2)      // copy %2
  //     %23 = aten::slice(%20, ...)     // slice for %3
  //     %24 = aten::copy_(%23, %3)      // copy %3
  //     ...
  //         = %20 ... // Use %20 in place of %11
  void expandCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for expansion: ", node);
    // Do not optimize cat nodes whose inputs are mutated in the graph.
    // TODO: Improve this by checking if it is mutated in the graph region
    // where this optimization is applied.
    if (getOrCreateAliasDb()->hasWriters(node->input(0))) {
      return;
    }
    if (node->input(0)->node()->kind() != prim::ListConstruct) {
      // Unknown form of input to `cat` op.
      return;
    }
    if (!allShapesAreKnown(node)) {
      // Can't expand when shapes are not known for the `cat` op.
      return;
    }
    for (auto cat_inp : node->input(0)->node()->inputs()) {
      if (!shapeIsKnown(cat_inp)) {
        // Can't expand when shapes of the inputs to `cat` are not known.
        return;
      }
    }
    // TODO: Handle non-contiguous Tensors.
    // For example, how to handle the cases where the inputs are all channels
    // last?

    auto maybe_cat_dim = constant_as<int64_t>(node->input(1));
    if (!maybe_cat_dim) {
      // Can't expand when cat dimension is not a constant.
      return;
    }

    auto cat_dim_value = maybe_cat_dim.value();
    auto cat_dim = node->input(1);
    // Set the insertion point to the current `cat` node.
    WithInsertPoint guard(node);
    auto none = graph_->insertConstant(IValue());
    auto one = graph_->insertConstant(1);

    // Insert the constants needed for the `cat` output buffer size.
    auto tensortype = node->output()->type()->expect<TensorType>();
    TORCH_INTERNAL_ASSERT(tensortype);
    auto tensortype_sizes = tensortype->sizes();
    std::vector<Value*> cat_out_size;
    for (size_t i = 0; i < tensortype_sizes.size(); ++i) {
      cat_out_size.push_back(graph_->insertConstant(tensortype_sizes[i]));
    }

    // Create a list of int for `cat` output buffer size.
    auto cat_out_size_list = graph_->createList(IntType::get(), cat_out_size);
    cat_out_size_list->insertBefore(node);

    // Create an empty buffer to be used as `cat` output buffer.
    // TODO: Handle tensors with different dtype, layout, device, memory
    // format, etc.
    auto cat_out_empty = graph_->create(
        aten::empty,
        {cat_out_size_list->output(), none, none, none, none, none});
    cat_out_empty->insertBefore(node);

    // For every input to this `cat` node:
    //   * Create a slice of `cat` output buffer.
    auto cat_out_value = cat_out_empty->output();
    auto cat_inp_list = node->input(0)->node();
    int64_t start_idx = 0;
    auto start = graph_->insertConstant(start_idx);
    for (auto cat_inp : cat_inp_list->inputs()) {
      // Create a slice of the cat output buffer that corresponds to
      // this input's size and position in the output.
      auto cat_inp_tensor_type =
          dynamic_cast<TensorType*>(cat_inp->type().get());
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type);
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type->dim());
      auto cat_inp_tensortype_sizes = cat_inp_tensor_type->sizes();
      auto end_idx = start_idx + *cat_inp_tensortype_sizes[cat_dim_value];
      auto end = graph_->insertConstant(end_idx);

      auto slice = graph_->create(
          aten::slice, {cat_out_value, cat_dim, start, end, one});
      GRAPH_UPDATE("Inserting\n", *slice, "before\n", *node);
      slice->insertBefore(node);
      slices_added_.push_back(slice);

      // Insert a copy from this input to the output slice.
      auto copy = graph_->create(aten::copy_, {slice->output(), cat_inp});
      GRAPH_UPDATE("Inserting\n", *copy, "before\n", *node);
      copy->insertBefore(node);
      copies_added_.push_back(copy);

      start_idx = end_idx;
      start = end;
    }
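    // Worked example (illustrative): concatenating inputs of shapes [2, 4]
    // and [3, 4] along dim 0 yields an output buffer of shape [5, 4]; the
    // loop above emits aten::slice(%buf, 0, 0, 2, 1) and
    // aten::slice(%buf, 0, 2, 5, 1) as the two copy destinations.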
    // Replace the uses of `cat` node with the cat output buffer.
    replace_uses_with_[node->output()] = cat_out_value;
    nodes_to_remove_.insert(node);
  }

  bool shapeIsKnown(Value* v) {
    if (v->type()->cast<TensorType>()) {
      if (!v->isCompleteTensor()) {
        return false;
      }
      if (*v->type()->castRaw<TensorType>()->dim() == 0) {
        return false;
      }
    }
    return true;
  }

  bool allShapesAreKnown(Node* node) {
    // TODO: Relax the checks to support dynamic shapes
    for (Value* input : node->inputs()) {
      if (!shapeIsKnown(input)) {
        return false;
      }
    }
    for (Value* output : node->outputs()) {
      if (!shapeIsKnown(output)) {
        return false;
      }
    }
    return true;
  }

  void cleanupExpandedCatOps() {
    for (auto it : replace_uses_with_) {
      GRAPH_UPDATE(
          "Replacing uses of\n",
          *it.first->node(),
          "with\n",
          *it.second->node());
      it.first->replaceAllUsesWith(it.second);
    }
    for (auto n : nodes_to_remove_) {
      removeCatNodeFromGraph(n);
    }
  }

  void moveBefore(Node* node, Node* before) {
    // In order to move a node before another node, we need to move
    // all the nodes it depends on as well.
    for (auto inp : node->inputs()) {
      moveBefore(inp->node(), before);
    }
    node->moveBefore(before);
  }
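  // Note: the recursion in moveBefore() also relocates the node's transitive
  // input producers ahead of `before`, keeping the moved nodes topologically
  // valid. As used here (on the slice nodes this pass creates), that means
  // the `aten::empty` buffer producer and the index constants move along
  // with the slice.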
GRAPH_UPDATE("Moving\n", *dst->node(), "before\n", *src->node()); moveBefore(dst->node(), src->node()); GRAPH_UPDATE("Replacing\n", *src->node(), "with\n", *dst->node()); src->replaceAllUsesWith(dst); GRAPH_UPDATE("Deleting\n", *src->node()); src->node()->destroy(); GRAPH_UPDATE("Deleting\n", *copy); copy->destroy(); } } AliasDb* getOrCreateAliasDb() { if (!aliasDb_) { aliasDb_ = std::make_unique(graph_); } return aliasDb_.get(); } std::shared_ptr graph_; std::unique_ptr aliasDb_ = nullptr; std::unordered_set nodes_to_remove_; std::unordered_map replace_uses_with_; std::vector copies_added_; std::vector slices_added_; }; } // namespace void ExpandConcatAndEliminateRedundancy(const std::shared_ptr& graph) { ConcatExpander(graph).run(); GRAPH_DUMP("After expanding Concat and eliminating redundancy", graph); } namespace { size_t determineUsageIdx(Value* value, Node* user) { const auto idx = std::find(user->inputs().begin(), user->inputs().end(), value) - user->inputs().begin(); using c10::ssize; TORCH_CHECK(idx != ssize(user->inputs())); return idx; } std::vector getConcatInputs(Node* concat) { TORCH_CHECK(concat->kind() == aten::cat); auto* list = concat->input(0); auto* list_construct = list->node(); TORCH_CHECK(list_construct->kind() == prim::ListConstruct); return list_construct->inputs().vec(); } class ConcatCombiner { public: explicit ConcatCombiner(std::shared_ptr graph) : graph_(std::move(graph)), aliasDb_(graph_) {} bool run() { collectOptimizableConcats(); bool changed = combineConcats(); if (changed) { EliminateDeadCode(graph_); } return changed; } private: // Given a concat node, see if it can be optimized with another. // If so, add a CombinablePair to combinable_concats_. void handleConcat(Node* node) { auto* list = node->input(0); auto* list_node = list->node(); const auto dim_opt = toIValue(node->input(1)); // We need to be able to determine dim statically to match it with another // concat. if (!dim_opt || !dim_opt->isInt()) { return; } const auto dim = dim_opt->toInt(); // Check that the input of this node is an unmodified list construct if (list_node->kind() != prim::ListConstruct || !aliasDb_.couldMoveBeforeTopologically(list_node, node)) { return; } // Check that the only output of this node is used in an unmodified list // construct. const auto& concat_uses = node->output()->uses(); if (concat_uses.size() != 1) { return; } auto* next_list = concat_uses[0].user; if (next_list->kind() != prim::ListConstruct) { return; } const auto& next_list_uses = next_list->output()->uses(); if (next_list_uses.size() != 1) { return; } auto* next_concat = next_list_uses[0].user; if (next_concat->kind() == aten::cat) { // Dimension must be determined statically and match the one we've already // seen. const auto next_dim_opt = toIValue(next_concat->input(1)); if (!next_dim_opt || next_dim_opt->toInt() != dim) { return; } combinable_concats_.emplace_back( node, next_concat, determineUsageIdx(node->output(), next_list)); } } void collectOptimizableConcats() { DepthFirstGraphNodeIterator graph_it(graph_); for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) { if (node->kind() == aten::cat) { handleConcat(node); } } } Node* createListConstruct(const std::deque& inputs) { auto* output = graph_->create(prim::ListConstruct); for (auto* v : inputs) { output->addInput(v); } return output; } using ListConstructInputs = std::shared_ptr>; // Construct a map (concat node) -> (new list inputs for this node). // std::deque is used so we can do O(1) insertions to the front. 
  std::unordered_map<Node*, ListConstructInputs> getListConstructInputs() {
    std::unordered_map<Node*, ListConstructInputs> cur_list_construct_inputs;
    for (const auto& combinable : combinable_concats_) {
      // Combine the list inputs of first_concat with those of second_concat.
      const auto& inputs_to_add = getConcatInputs(combinable.second_concat);

      auto it = cur_list_construct_inputs.find(combinable.first_concat);
      std::shared_ptr<std::deque<Value*>> cur_list;
      if (it != cur_list_construct_inputs.end()) {
        cur_list = it->second;
        // We're moving all inputs to second_concat.
        cur_list_construct_inputs.erase(combinable.first_concat);
      } else {
        cur_list = std::make_shared<std::deque<Value*>>();
      }
      cur_list_construct_inputs.emplace(combinable.second_concat, cur_list);

      // If cur_list is not empty, it's guaranteed to already contain all of
      // first_concat's inputs.
      if (cur_list->empty()) {
        const auto& starting_values = getConcatInputs(combinable.first_concat);
        cur_list->insert(
            cur_list->end(), starting_values.begin(), starting_values.end());
      }

      cur_list->insert(
          cur_list->begin(),
          inputs_to_add.begin(),
          inputs_to_add.begin() + combinable.idx);
      cur_list->insert(
          cur_list->end(),
          inputs_to_add.begin() + combinable.idx + 1,
          inputs_to_add.end());
    }
    return cur_list_construct_inputs;
  }

  bool combineConcats() {
    if (combinable_concats_.empty()) {
      return false;
    }

    auto list_construct_inputs = getListConstructInputs();

    for (const auto& node_and_new_list : list_construct_inputs) {
      auto* node = node_and_new_list.first;
      auto& inputs = node_and_new_list.second;

      auto* new_list_construct = createListConstruct(*inputs);
      auto* old_list_construct = node->input(0)->node();
      new_list_construct->output()->setType(
          old_list_construct->output()->type());
      new_list_construct->insertBefore(node);
      old_list_construct->replaceAllUsesWith(new_list_construct);
    }
    return true;
  }

  // Represents an optimizable pair of concat nodes.
  // - first_concat must appear before second_concat
  // - idx is the index where first_concat's inputs must be inserted into
  //   second_concat's new inputs.
  // Example:
  //    %inputs.1 = prim::ListConstruct(%0, %0)
  //    %concat.1 = aten::cat(%inputs.1, %dim)
  //    %inputs.2 = prim::ListConstruct(%1, %concat.1, %1)
  //    %concat.2 = aten::cat(%inputs.2, %dim)
  // -> first_concat = &concat.1, second_concat = &concat.2, idx = 1
  struct CombinableConcat {
    CombinableConcat(Node* a, Node* b, size_t i)
        : first_concat(a), second_concat(b), idx(i) {}

    Node* first_concat;
    Node* second_concat;
    size_t idx;
  };

  std::vector<CombinableConcat> combinable_concats_;

  std::shared_ptr<Graph> graph_;
  AliasDb aliasDb_;
};

} // namespace

bool CombineConcats(const std::shared_ptr<Graph>& graph) {
  bool changed = ConcatCombiner(graph).run();
  GRAPH_DUMP("After combining concats", graph);
  return changed;
}

} // namespace torch::jit
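// Example usage (a minimal sketch):
//   bool changed = CombineConcats(graph);
//   // Nested chains such as
//   //   aten::cat([%x, aten::cat([%y, %z], %dim)], %dim)
//   // collapse into aten::cat([%x, %y, %z], %dim), provided both dims are
//   // the same constant integer and the inner cat has no other uses; dead
//   // list constructs are then removed via EliminateDeadCode.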