mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add more list peephole idioms (#48268)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48268 Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D25104617 Pulled By: eellison fbshipit-source-id: b41c03d5da6e9b88acf21a859f61c5c70608c150
This commit is contained in:
parent
39d3578e91
commit
9058040527
|
|
@ -89,6 +89,42 @@ struct PeepholeOptimizeListIdiomsImpl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if (node->kind() == prim::ListUnpack) {
|
||||||
|
auto list_creation_node = first_input->node();
|
||||||
|
if (list_creation_node->kind() == prim::ListConstruct) {
|
||||||
|
// if sizes are unequal it's a runtime error
|
||||||
|
if (list_creation_node->inputs().size() != node->outputs().size()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < node->outputs().size(); ++i) {
|
||||||
|
node->output(i)->replaceAllUsesWith(list_creation_node->inputs().at(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (node->kind() == aten::add) {
|
||||||
|
if (node->inputs().size() != 2) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto second_input = node->inputs().at(1);
|
||||||
|
// already checked first, need to check second
|
||||||
|
if (mutated_lists_.count(second_input)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (first_input->node()->kind() != prim::ListConstruct || second_input->node()->kind() != prim::ListConstruct) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
WithInsertPoint guard(node);
|
||||||
|
auto list_construct = graph_->insertNode(graph_->create(prim::ListConstruct));
|
||||||
|
list_construct->output()->setType(node->output()->type());
|
||||||
|
for (Value * v: first_input->node()->inputs()) {
|
||||||
|
list_construct->addInput(v);
|
||||||
|
}
|
||||||
|
for (Value * v: second_input->node()->inputs()) {
|
||||||
|
list_construct->addInput(v);
|
||||||
|
}
|
||||||
|
node->output()->replaceAllUsesWith(list_construct->output());
|
||||||
|
if (mutated_lists_.count(node->output())) {
|
||||||
|
mutated_lists_.insert(list_construct->output());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user