[JIT] Don't optimize shape info in batch_mm (#44565)

Summary:
We run the remove-profile-nodes-and-specialize-types pass before BatchMM, so at that point tensor type information has not yet been guarded. Peephole optimizations must therefore not rely on tensor type properties such as shapes, since those properties are not guaranteed to be correct.
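As a concrete illustration of the failure mode, mirroring the regression test added below: if the peephole pass folded x.size() using an unguarded profiled shape, a later call with a differently shaped tensor would observe a stale size. A minimal sketch (the shapes here are illustrative):

```python
import torch
from torch.testing import FileCheck

@torch.jit.script
def foo(x, y):
    # x.size() exercises the shape peepholes that run in BatchMM's cleanup
    return x + y + 1 + 2 + 3, x.size()

x = torch.ones(1)
foo(x, x)  # profiling run
foo(x, x)  # optimized graph is compiled here

# aten::size must survive optimization: the profiled shape was never guarded
g = torch.jit.last_executed_optimized_graph()
FileCheck().check("aten::size").run(g)

# a differently shaped input must still report its real size
x = torch.ones([2, 3, 5])
out, size = foo(x, x)
assert size == [2, 3, 5]
```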

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44565

Reviewed By: albanD

Differential Revision: D23661538

Pulled By: eellison

fbshipit-source-id: 0dd23a65714f047f49b4db4ec582b21870925fe1
Author: Elias Ellison <2020-09-14 12:23:09 -07:00>, committed by Facebook GitHub Bot
Parent: e261e0953e
Commit: 856510c96d
2 changed files with 16 additions and 1 deletion


@@ -146,6 +146,19 @@ class TestProfiler(JitTestCase):
         g = torch.jit.last_executed_optimized_graph()
         FileCheck().check_not("TensorExpr").run(g)
 
+    def test_not_optimizing_property(self):
+        @torch.jit.script
+        def foo(x, y):
+            return x + y + 1 + 2 + 3, x.size()
+
+        x = torch.ones(1)
+        foo(x, x)
+        foo(x, x)
+        g = torch.jit.last_executed_optimized_graph()
+        FileCheck().check("aten::size").run(g)
+        x = torch.ones([2, 3, 5])
+        self.assertEqual(foo(x, x), (x + x + 1 + 2 + 3, x.size()))
+
     def test_fallback_graph_not_specialized(self):
         @torch.jit.script
         def foo(a, b):


@@ -475,7 +475,9 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
   EliminateDeadCode(graph);
   // It's possible that transpose rearrangements have created sequences of
   // consecutive transposes that didn't exist before.
-  PeepholeOptimize(graph);
+
+  // tensor type properties are not guaranteed to be correct
+  PeepholeOptimize(graph, /*disable_shape_peepholes*/ true);
 }
 } // namespace jit
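For context, BatchMM rewrites trees of matrix multiplies that share a common operand into a batched multiply, then runs the cleanup shown above. A minimal sketch for inspecting the optimized graph this pass produces; whether batching actually fires depends on the executor's heuristics and profiling runs, and the function name and shapes here are illustrative:

```python
import torch

@torch.jit.script
def mm_tree(a, b, c):
    # two matmuls sharing the operand `a`, summed together:
    # the add-of-mm tree pattern BatchMM looks for
    return a @ b + a @ c

a = torch.randn(4, 4)
mm_tree(a, a.clone(), a.clone())  # profiling run
mm_tree(a, a.clone(), a.clone())  # optimized graph is compiled here
print(torch.jit.last_executed_optimized_graph())
```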