#include <torch/csrc/jit/passes/concat_opt.h>

#include <algorithm>
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <c10/util/ssize.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/named_value.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

namespace torch::jit {

namespace {
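
// Destroys a dead aten::cat node and, if the list feeding it is left with no
// other uses, the node that produced that list as well.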
void removeCatNodeFromGraph(Node* n) {
  TORCH_INTERNAL_ASSERT(n->kind() == aten::cat);
  auto inp_list = n->input(0);
  GRAPH_UPDATE("Deleting\n", *n);
  n->destroy();

  if (!inp_list->hasUses()) {
    GRAPH_UPDATE("Deleting\n", *inp_list->node());
    inp_list->node()->destroy();
  }
}
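
// Compares two input lists by length and element-wise Value* identity; this
// does not compare tensor contents.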
bool equal(at::ArrayRef<Value*> list1, at::ArrayRef<Value*> list2) {
  return list1.size() == list2.size() &&
      std::equal(list1.begin(), list1.end(), list2.begin());
}
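
// Eliminates repeated common inputs across prim::VarConcat nodes: when the
// tensor inputs of one VarConcat form a prefix or suffix of a later
// VarConcat on the same dim, the later node is rewritten to consume the
// earlier node's output in place of the repeated inputs.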
class ConcatCommonInputsEliminator {
 public:
  explicit ConcatCommonInputsEliminator(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  bool run() {
    handleBlock(graph_->block());
    return postprocess();
  }

 private:
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == prim::VarConcat) {
        handleCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  void handleCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for CSE opt: ", node);

    auto curr_all_inputs = node->inputs();
    auto curr_tensor_inputs =
        curr_all_inputs.slice(0, curr_all_inputs.size() - 1);
    auto curr_dim = curr_all_inputs.back();

    // Save the input list and the current cat node, so that this can be
    // used for subsequent cat nodes, unless there are writes to this cat
    // node. When there are writes to this cat node, its output does not
    // represent this concatenated list beyond the writes. Currently, we do
    // not perform such fine-grained analysis. So, if there are any writes to
    // the output, we do not use this cat node for optimization here.
    if (!getOrCreateAliasDb()->hasWriters(node->output())) {
      concated_outputs_.insert(node);
    }

    if (curr_tensor_inputs.size() <= 2) {
      // The case when concat has 2 input tensors could only be optimized if
      // there is another concat of the exact same 2 input tensors. That case
      // is expected to be handled by the CSE pass.
      return;
    }

    // Now, we check if the first N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %13 = prim::VarConcat(%0, %1, %2, <dim>) // first 2 inputs same as %11
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %14 = prim::VarConcat(%11, %2, <dim>) // Replace first 2 inputs
    //                                          // with %11
    //    ...
    //        = %14 ... // Replace use of %13 with %14
    auto curr_tensor_inputs_prefix =
        curr_tensor_inputs.slice(0, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_prefix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated output if it does not
          // dominate the current concat node.
          continue;
        }

        std::vector<Value*> new_inputs = {
            prev->output(), curr_tensor_inputs.back(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Now, we check if the last N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %10 = prim::ListConstruct(%1, %2)
    //    %11 = aten::cat(%10, ...)
    //    ...
    //    %12 = prim::ListConstruct(%0, %1, %2) // last 2 inputs same as %11
    //    %13 = aten::cat(%12, ...)
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %10 = prim::ListConstruct(%1, %2)
    //    %11 = aten::cat(%10, ...)
    //    ...
    //    %12 = prim::ListConstruct(%0, %11) // Replace last 2 inputs with %11
    //    %13 = aten::cat(%12, ...)
    //    ...
    //        = %13 ... // Use %13
    auto curr_tensor_inputs_suffix =
        curr_tensor_inputs.slice(1, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_suffix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated list if it does not
          // dominate the current list.
          continue;
        }

        std::vector<Value*> new_inputs = {
            curr_tensor_inputs.front(), prev->output(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Do we need to handle other cases where N-2 or lesser elements from
    // %inputs appear in any of the previous cat ops?
    // TODO.
  }

  bool postprocess() {
    // Replace the cat nodes that have been marked for replacement.
    bool changed = false;
    for (auto it : concats_to_replace_) {
      auto curr_node = it.first;
      auto new_node = it.second;
      GRAPH_UPDATE("Inserting\n", *new_node, "before\n", *curr_node);
      new_node->insertBefore(curr_node);
      GRAPH_UPDATE("Replacing uses of\n", *curr_node, "with\n", *new_node);
      curr_node->output()->replaceAllUsesWith(new_node->output());
      GRAPH_UPDATE("Deleting\n", *curr_node);
      curr_node->destroy();
      changed = true;
    }
    return changed;
  }

  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  std::unordered_set<Node*> concated_outputs_;
  std::unordered_map<Node*, Node*> concats_to_replace_;
};

} // namespace
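
// A minimal usage sketch (hypothetical caller, not part of this file):
//   std::shared_ptr<Graph> graph = ...; // e.g. a method's graph
//   bool changed = EliminateConcatCommonInputs(graph);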
bool EliminateConcatCommonInputs(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before eliminating Concat common inputs", graph);
  bool changed = ConcatCommonInputsEliminator(graph).run();
  if (changed) {
    GRAPH_DUMP("After eliminating Concat common inputs", graph);
  }
  return changed;
}

namespace {
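
// Expands each aten::cat with statically known shapes into an explicit
// output buffer (aten::empty) plus one aten::slice / aten::copy_ pair per
// input, then reuses buffers across chained cats where possible.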
class ConcatExpander {
 public:
  explicit ConcatExpander(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void run() {
    handleBlock(graph_->block());
    cleanupExpandedCatOps();
    GRAPH_DUMP("Before reusing copy buffers: ", graph_);
    reuseBuffersInCopies();
  }

 private:
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == aten::cat) {
        expandCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  // Expand cat node into multiple copy nodes.
  //
  // Example:
  //    %2 = aten::clamp(%0, ...)
  //    %3 = aten::clamp(%1, ...)
  //    %10 = prim::ListConstruct(%2, %3)
  //    %11 = aten::cat(%10, ...)
  //    ...
  //        = %11 ... // Use %11
  //
  // After expanding cat:
  //    %2 = aten::clamp(%0, ...)
  //    %3 = aten::clamp(%1, ...)
  //    %20 = aten::empty(...)       // cat output buffer
  //    %21 = aten::slice(%20, ...)  // slice for %2
  //    %22 = aten::copy_(%21, %2)   // copy %2
  //    %23 = aten::slice(%20, ...)  // slice for %3
  //    %24 = aten::copy_(%23, %3)   // copy %3
  //    ...
  //        = %20 ... // Use %20 in place of %11
  void expandCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for expansion: ", node);
    // Do not optimize cat nodes whose inputs are mutated in the graph.
    // TODO: Improve this by checking if it is mutated in the graph region
    // where this optimization is applied.
    if (getOrCreateAliasDb()->hasWriters(node->input(0))) {
      return;
    }
    if (node->input(0)->node()->kind() != prim::ListConstruct) {
      // Unknown form of input to `cat` op.
      return;
    }
    if (!allShapesAreKnown(node)) {
      // Can't expand when shapes are not known for the `cat` op.
      return;
    }
    for (auto cat_inp : node->input(0)->node()->inputs()) {
      if (!shapeIsKnown(cat_inp)) {
        // Can't expand when shapes of the inputs to `cat` are not known.
        return;
      }
    }
    // TODO: Handle non-contiguous Tensors.
    // For example, how to handle the cases where the inputs are all channels
    // last?

    auto maybe_cat_dim = constant_as<int64_t>(node->input(1));
    if (!maybe_cat_dim) {
      // Can't expand when cat dimension is not a constant.
      return;
    }
    auto cat_dim_value = maybe_cat_dim.value();
    auto cat_dim = node->input(1);

    // Set the insertion point to the current `cat` node.
    WithInsertPoint guard(node);
    auto none = graph_->insertConstant(IValue());
    auto one = graph_->insertConstant(1);

    // Insert the constants needed for the `cat` output buffer size.
    auto tensortype = node->output()->type()->expect<TensorType>();
    TORCH_INTERNAL_ASSERT(tensortype);
    auto tensortype_sizes = tensortype->sizes();
    std::vector<Value*> cat_out_size;
    for (size_t i = 0; i < tensortype_sizes.size(); ++i) {
      cat_out_size.push_back(graph_->insertConstant(tensortype_sizes[i]));
    }

    // Create a list of int for `cat` output buffer size.
    auto cat_out_size_list = graph_->createList(IntType::get(), cat_out_size);
    cat_out_size_list->insertBefore(node);

    // Create an empty buffer to be used as `cat` output buffer.
    // TODO: Handle tensors with different dtype, layout, device, memory
    // format, etc.
    auto cat_out_empty = graph_->create(
        aten::empty,
        {cat_out_size_list->output(), none, none, none, none, none});
    cat_out_empty->insertBefore(node);

    // For every input to this `cat` node:
    //   * Create a slice of `cat` output buffer.
    auto cat_out_value = cat_out_empty->output();
    auto cat_inp_list = node->input(0)->node();
    int64_t start_idx = 0;
    auto start = graph_->insertConstant(start_idx);
    for (auto cat_inp : cat_inp_list->inputs()) {
      // Create a slice of the cat output buffer that corresponds to
      // this input's size and position in the output.
      auto cat_inp_tensor_type =
          dynamic_cast<TensorType*>(cat_inp->type().get());
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type);
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type->dim());
      auto cat_inp_tensortype_sizes = cat_inp_tensor_type->sizes();
      auto end_idx = start_idx + *cat_inp_tensortype_sizes[cat_dim_value];
      auto end = graph_->insertConstant(end_idx);

      auto slice = graph_->create(
          aten::slice, {cat_out_value, cat_dim, start, end, one});
      GRAPH_UPDATE("Inserting\n", *slice, "before\n", *node);
      slice->insertBefore(node);
      slices_added_.push_back(slice);

      // Insert a copy from this input to the output slice.
      auto copy = graph_->create(aten::copy_, {slice->output(), cat_inp});
      GRAPH_UPDATE("Inserting\n", *copy, "before\n", *node);
      copy->insertBefore(node);
      copies_added_.push_back(copy);

      start_idx = end_idx;
      start = end;
    }

    // Replace the uses of `cat` node with the cat output buffer.
    replace_uses_with_[node->output()] = cat_out_value;
    nodes_to_remove_.insert(node);
  }
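
  // Non-tensor values count as "known"; tensor values must carry complete
  // shape information and have rank >= 1.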
  bool shapeIsKnown(Value* v) {
    if (v->type()->cast<TensorType>()) {
      if (!v->isCompleteTensor()) {
        return false;
      }
      if (*v->type()->castRaw<TensorType>()->dim() == 0) {
        return false;
      }
    }
    return true;
  }

  bool allShapesAreKnown(Node* node) {
    // TODO: Relax the checks to support dynamic shapes
    for (Value* input : node->inputs()) {
      if (!shapeIsKnown(input)) {
        return false;
      }
    }
    for (Value* output : node->outputs()) {
      if (!shapeIsKnown(output)) {
        return false;
      }
    }
    return true;
  }

  void cleanupExpandedCatOps() {
    for (auto it : replace_uses_with_) {
      GRAPH_UPDATE(
          "Replacing uses of\n",
          *it.first->node(),
          "with\n",
          *it.second->node());
      it.first->replaceAllUsesWith(it.second);
    }
    for (auto n : nodes_to_remove_) {
      removeCatNodeFromGraph(n);
    }
  }

  void moveBefore(Node* node, Node* before) {
    // In order to move a node before another node, we need to move
    // all the nodes it depends on as well.
    for (auto inp : node->inputs()) {
      moveBefore(inp->node(), before);
    }
    node->moveBefore(before);
  }

  // Reuse buffers in copies wherever possible.
  //
  // For example, consider the following sequence of ops:
  //    %10 = prim::ListConstruct(%0, %1)
  //    %11 = aten::cat(%10, ...)
  //    ...
  //    %12 = prim::ListConstruct(%11, %2) // Uses the result of the cat above
  //    %13 = aten::cat(%12, ...)
  //
  // Once these cat ops are expanded into copies, we will have two buffers:
  // one for %11 and another for %13. This can be optimized by using only one
  // buffer: we keep just the buffer that represents %13 and use a view
  // (slice) of it as the buffer for %11.
  //
  // Concretely, whenever a copy added earlier has an `aten::empty` output as
  // its source, that buffer can be eliminated and its uses redirected to the
  // copy's destination slice.
  //
  // Example:
  //    %20 = aten::empty(...)       // cat.1 output buffer
  //    %21 = aten::slice(%20, ...)
  //    %22 = aten::copy_(%21, %2)
  //    %23 = aten::slice(%20, ...)
  //    %24 = aten::copy_(%23, %3)
  //    ...
  //    %30 = aten::empty(...)       // cat.2 output buffer
  //    %31 = aten::slice(%30, ...)
  //    %32 = aten::copy_(%31, %20)  // src of copy is aten::empty,
  //                                 // so we reuse this buffer above
  //    %33 = aten::slice(%30, ...)
  //    %34 = aten::copy_(%33, %4)
  //
  // After reusing copy buffers:
  //    %30 = aten::empty(...)       // cat.2 output buffer
  //    %31 = aten::slice(%30, ...)  // move %31 and its inputs before %20
  //    %21 = aten::slice(%31, ...)  // use %31 in place of %20
  //    %22 = aten::copy_(%21, %2)
  //    %23 = aten::slice(%31, ...)  // use %31 in place of %20
  //    %24 = aten::copy_(%23, %3)
  //    ...
  //    ...                          // copy to %31 is now removed
  //    %33 = aten::slice(%30, ...)
  //    %34 = aten::copy_(%33, %4)
  void reuseBuffersInCopies() {
    for (auto copy : copies_added_) {
      auto src = copy->input(1);
      auto dst = copy->input(0);
      if (src->node()->kind() != aten::empty) {
        continue;
      }

      // Move the destination node before the source.
      GRAPH_UPDATE("Moving\n", *dst->node(), "before\n", *src->node());
      moveBefore(dst->node(), src->node());

      GRAPH_UPDATE("Replacing\n", *src->node(), "with\n", *dst->node());
      src->replaceAllUsesWith(dst);

      GRAPH_UPDATE("Deleting\n", *src->node());
      src->node()->destroy();

      GRAPH_UPDATE("Deleting\n", *copy);
      copy->destroy();
    }
  }

  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  std::unordered_set<Node*> nodes_to_remove_;
  std::unordered_map<Value*, Value*> replace_uses_with_;
  std::vector<Node*> copies_added_;
  std::vector<Node*> slices_added_;
};

} // namespace
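
// Note: unlike the bool-returning passes in this file, this pass does not
// report whether the graph changed; also, the expansion introduces in-place
// ops (aten::copy_ into slices of a shared buffer).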
void ExpandConcatAndEliminateRedundancy(const std::shared_ptr<Graph>& graph) {
  ConcatExpander(graph).run();
  GRAPH_DUMP("After expanding Concat and eliminating redundancy", graph);
}

namespace {
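
// Returns the position of `value` in `user`'s input list; fails the
// TORCH_CHECK if `value` is not among the inputs.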
size_t determineUsageIdx(Value* value, Node* user) {
  const auto idx =
      std::find(user->inputs().begin(), user->inputs().end(), value) -
      user->inputs().begin();
  using c10::ssize;
  TORCH_CHECK(idx != ssize(user->inputs()));
  return idx;
}
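
// Returns the tensors being concatenated: the inputs of the
// prim::ListConstruct that feeds `concat`'s first argument.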
std::vector<Value*> getConcatInputs(Node* concat) {
  TORCH_CHECK(concat->kind() == aten::cat);
  auto* list = concat->input(0);
  auto* list_construct = list->node();
  TORCH_CHECK(list_construct->kind() == prim::ListConstruct);
  return list_construct->inputs().vec();
}
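
// Combines a cat whose only use is as an input to another cat on the same
// dim into that second cat, splicing the first cat's inputs directly into
// the second cat's input list (see CombinableConcat below for an example).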
class ConcatCombiner {
 public:
  explicit ConcatCombiner(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)), aliasDb_(graph_) {}

  bool run() {
    collectOptimizableConcats();
    bool changed = combineConcats();
    if (changed) {
      EliminateDeadCode(graph_);
    }
    return changed;
  }

 private:
  // Given a concat node, see if it can be optimized with another.
  // If so, add a CombinableConcat to combinable_concats_.
  void handleConcat(Node* node) {
    auto* list = node->input(0);
    auto* list_node = list->node();

    const auto dim_opt = toIValue(node->input(1));
    // We need to be able to determine dim statically to match it with
    // another concat.
    if (!dim_opt || !dim_opt->isInt()) {
      return;
    }
    const auto dim = dim_opt->toInt();

    // Check that the input of this node is an unmodified list construct.
    if (list_node->kind() != prim::ListConstruct ||
        !aliasDb_.couldMoveBeforeTopologically(list_node, node)) {
      return;
    }

    // Check that the only output of this node is used in an unmodified list
    // construct.
    const auto& concat_uses = node->output()->uses();
    if (concat_uses.size() != 1) {
      return;
    }

    auto* next_list = concat_uses[0].user;
    if (next_list->kind() != prim::ListConstruct) {
      return;
    }

    const auto& next_list_uses = next_list->output()->uses();
    if (next_list_uses.size() != 1) {
      return;
    }

    auto* next_concat = next_list_uses[0].user;

    if (next_concat->kind() == aten::cat) {
      // Dimension must be determined statically and match the one we've
      // already seen.
      const auto next_dim_opt = toIValue(next_concat->input(1));
      if (!next_dim_opt || next_dim_opt->toInt() != dim) {
        return;
      }
      combinable_concats_.emplace_back(
          node, next_concat, determineUsageIdx(node->output(), next_list));
    }
  }

  void collectOptimizableConcats() {
    DepthFirstGraphNodeIterator graph_it(graph_);
    for (auto* node = graph_it.next(); node != nullptr;
         node = graph_it.next()) {
      if (node->kind() == aten::cat) {
        handleConcat(node);
      }
    }
  }
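
  // Builds a fresh prim::ListConstruct node from `inputs`; the caller is
  // responsible for inserting it into the graph.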
  Node* createListConstruct(const std::deque<Value*>& inputs) {
    auto* output = graph_->create(prim::ListConstruct);
    for (auto* v : inputs) {
      output->addInput(v);
    }
    return output;
  }

  using ListConstructInputs = std::shared_ptr<std::deque<Value*>>;
  // Construct a map (concat node) -> (new list inputs for this node).
  // std::deque is used so we can do O(1) insertions to the front.
  std::unordered_map<Node*, ListConstructInputs> getListConstructInputs() {
    std::unordered_map<Node*, ListConstructInputs> cur_list_construct_inputs;
    for (const auto& combinable : combinable_concats_) {
      // Combine the list inputs of first_concat with those of second_concat.
      const auto& inputs_to_add = getConcatInputs(combinable.second_concat);

      auto it = cur_list_construct_inputs.find(combinable.first_concat);
      std::shared_ptr<std::deque<Value*>> cur_list;
      if (it != cur_list_construct_inputs.end()) {
        cur_list = it->second;
        // We're moving all inputs to second_concat.
        cur_list_construct_inputs.erase(combinable.first_concat);
      } else {
        cur_list = std::make_shared<std::deque<Value*>>();
      }
      cur_list_construct_inputs.emplace(combinable.second_concat, cur_list);

      // If cur_list is not empty, it's guaranteed to already contain all of
      // first_concat's inputs.
      if (cur_list->empty()) {
        const auto& starting_values = getConcatInputs(combinable.first_concat);
        cur_list->insert(
            cur_list->end(), starting_values.begin(), starting_values.end());
      }
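
      // cur_list now holds (at least) first_concat's flattened inputs.
      // inputs_to_add holds second_concat's inputs, where position `idx` is
      // first_concat's output: splice everything before `idx` onto the front
      // and everything after `idx` onto the back, so first_concat's inputs
      // end up exactly where its output used to be.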
      cur_list->insert(
          cur_list->begin(),
          inputs_to_add.begin(),
          inputs_to_add.begin() + combinable.idx);

      cur_list->insert(
          cur_list->end(),
          inputs_to_add.begin() + combinable.idx + 1,
          inputs_to_add.end());
    }
    return cur_list_construct_inputs;
  }

  bool combineConcats() {
    if (combinable_concats_.empty()) {
      return false;
    }

    auto list_construct_inputs = getListConstructInputs();

    for (const auto& node_and_new_list : list_construct_inputs) {
      auto* node = node_and_new_list.first;
      auto& inputs = node_and_new_list.second;

      auto* new_list_construct = createListConstruct(*inputs);
      auto* old_list_construct = node->input(0)->node();
      new_list_construct->output()->setType(
          old_list_construct->output()->type());
      new_list_construct->insertBefore(node);
      old_list_construct->replaceAllUsesWith(new_list_construct);
    }
    return true;
  }

  // Represents an optimizable pair of concat nodes.
  // - first_concat must appear before second_concat
  // - idx is the index where first_concat's inputs must be inserted into
  //   second_concat's new inputs.
  // Example:
  //    %inputs.1 = prim::ListConstruct(%0, %0)
  //    %concat.1 = aten::cat(%inputs.1, %dim)
  //    %inputs.2 = prim::ListConstruct(%1, %concat.1, %1)
  //    %concat.2 = aten::cat(%inputs.2, %dim)
  // -> first_concat = &concat.1, second_concat = &concat.2, idx = 1
  struct CombinableConcat {
    CombinableConcat(Node* a, Node* b, size_t i)
        : first_concat(a), second_concat(b), idx(i) {}

    Node* first_concat;
    Node* second_concat;
    size_t idx;
  };

  std::vector<CombinableConcat> combinable_concats_;

  std::shared_ptr<Graph> graph_;
  AliasDb aliasDb_;
};

} // namespace

bool CombineConcats(const std::shared_ptr<Graph>& graph) {
  bool changed = ConcatCombiner(graph).run();
  GRAPH_DUMP("After combining concats", graph);
  return changed;
}

} // namespace torch::jit