Add more list peephole idioms (#48268)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48268

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D25104617

Pulled By: eellison

fbshipit-source-id: b41c03d5da6e9b88acf21a859f61c5c70608c150
This commit is contained in:
Elias Ellison 2020-12-17 20:24:12 -08:00 committed by Facebook GitHub Bot
parent 39d3578e91
commit 9058040527

View File

@ -89,6 +89,42 @@ struct PeepholeOptimizeListIdiomsImpl {
}
}
}
} else if (node->kind() == prim::ListUnpack) {
auto list_creation_node = first_input->node();
if (list_creation_node->kind() == prim::ListConstruct) {
// if sizes are unequal it's a runtime error
if (list_creation_node->inputs().size() != node->outputs().size()) {
continue;
}
for (size_t i = 0; i < node->outputs().size(); ++i) {
node->output(i)->replaceAllUsesWith(list_creation_node->inputs().at(i));
}
}
} else if (node->kind() == aten::add) {
if (node->inputs().size() != 2) {
continue;
}
auto second_input = node->inputs().at(1);
// already checked first, need to check second
if (mutated_lists_.count(second_input)) {
continue;
}
if (first_input->node()->kind() != prim::ListConstruct || second_input->node()->kind() != prim::ListConstruct) {
continue;
}
WithInsertPoint guard(node);
auto list_construct = graph_->insertNode(graph_->create(prim::ListConstruct));
list_construct->output()->setType(node->output()->type());
for (Value * v: first_input->node()->inputs()) {
list_construct->addInput(v);
}
for (Value * v: second_input->node()->inputs()) {
list_construct->addInput(v);
}
node->output()->replaceAllUsesWith(list_construct->output());
if (mutated_lists_.count(node->output())) {
mutated_lists_.insert(list_construct->output());
}
}
}
}