mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Follows #132604.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132753
Approved by: https://github.com/Skylion007
23 lines
580 B
C++
23 lines
580 B
C++
#include <torch/csrc/jit/passes/remove_expands.h>
|
|
|
|
namespace torch::jit {
|
|
|
|
static void RemoveExpands(Block* block) {
|
|
for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
|
|
++it) {
|
|
for (auto sub : it->blocks())
|
|
RemoveExpands(sub);
|
|
|
|
if (it->kind() == aten::expand && it->get<bool>(attr::implicit) == true) {
|
|
it->output()->replaceAllUsesWith(it->namedInput(attr::self));
|
|
it.destroyCurrent();
|
|
}
|
|
}
|
|
}
|
|
|
|
void RemoveExpands(const std::shared_ptr<Graph>& graph) {
|
|
RemoveExpands(graph->block());
|
|
}
|
|
|
|
} // namespace torch::jit
|