diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp index 484d89747a9..29e893ffc84 100644 --- a/torch/csrc/jit/passes/guard_elimination.cpp +++ b/torch/csrc/jit/passes/guard_elimination.cpp @@ -1,5 +1,6 @@ -#include #include +#include +#include #include #include @@ -13,8 +14,11 @@ struct GuardElimination { void run() { moveGuardsToDefs(graph_->block()); + GRAPH_DUMP("After moveGuardsToDefs", graph_); coalesceGuards(graph_->block()); + GRAPH_DUMP("After coalesceGuards", graph_); eliminateRedundantGuards(graph_->block()); + GRAPH_DUMP("After eliminateRedundantGuards", graph_); } void moveGuardsToDefs(Block* b) { @@ -31,7 +35,14 @@ struct GuardElimination { if (guardee->owningBlock() != n->owningBlock()) { guardee = *n->owningBlock()->nodes().begin(); } - aliasDb_->moveAfterTopologicallyValid(n, guardee); + bool moved = aliasDb_->moveAfterTopologicallyValid(n, guardee); + if (moved) { + GRAPH_UPDATE( + "Moved ", + n->output()->debugName(), + " to ", + n->inputs().at(0)->debugName()); + } } else { it++; for (Block* ib : n->blocks()) { @@ -55,6 +66,11 @@ struct GuardElimination { if (inputs_to_guards.count(n->input())) { auto prev = inputs_to_guards[n->input()]; n->output()->replaceAllUsesWith(prev->output()); + GRAPH_UPDATE( + "Replacing ", + n->output()->debugName(), + " with ", + prev->output()->debugName()); it.destroyCurrent(); } else { inputs_to_guards.insert({n->input(), n}); @@ -94,6 +110,8 @@ struct GuardElimination { auto pttp = n->output()->type(); n->output()->replaceAllUsesWith(n->inputs().at(0)); n->inputs().at(0)->setType(pttp); + GRAPH_UPDATE( + "Eliminating the redundant guard ", n->output()->debugName()); it.destroyCurrent(); } else { it++;