#include namespace torch::jit { static void RemoveExpands(Block* block) { for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; ++it) { for (auto sub : it->blocks()) RemoveExpands(sub); if (it->kind() == aten::expand && it->get(attr::implicit) == true) { it->output()->replaceAllUsesWith(it->namedInput(attr::self)); it.destroyCurrent(); } } } void RemoveExpands(const std::shared_ptr& graph) { RemoveExpands(graph->block()); } } // namespace torch::jit