mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[JIT] Dont optimize shape info in batch_mm (#44565)
Summary: We run remove profile nodes and specialize types before batch_mm, so we cannot run peepholes on the type information of tensors since these properties have not been guarded to be guaranteed to be correct. Pull Request resolved: https://github.com/pytorch/pytorch/pull/44565 Reviewed By: albanD Differential Revision: D23661538 Pulled By: eellison fbshipit-source-id: 0dd23a65714f047f49b4db4ec582b21870925fe1
This commit is contained in:
parent
e261e0953e
commit
856510c96d
|
|
@ -146,6 +146,19 @@ class TestProfiler(JitTestCase):
|
|||
g = torch.jit.last_executed_optimized_graph()
|
||||
FileCheck().check_not("TensorExpr").run(g)
|
||||
|
||||
def test_not_optimizing_property(self):
|
||||
@torch.jit.script
|
||||
def foo(x, y):
|
||||
return x + y + 1 + 2 + 3, x.size()
|
||||
|
||||
x = torch.ones(1)
|
||||
foo(x, x)
|
||||
foo(x, x)
|
||||
g = torch.jit.last_executed_optimized_graph()
|
||||
FileCheck().check("aten::size").run(g)
|
||||
x = torch.ones([2, 3, 5])
|
||||
self.assertEqual(foo(x, x), (x + x + 1 + 2 + 3, x.size()))
|
||||
|
||||
def test_fallback_graph_not_specialized(self):
|
||||
@torch.jit.script
|
||||
def foo(a, b):
|
||||
|
|
|
|||
|
|
@ -475,7 +475,9 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
|
|||
EliminateDeadCode(graph);
|
||||
// It's possible that transpose rearrangements have created sequences of
|
||||
// consecutive transposes that didn't exist before.
|
||||
PeepholeOptimize(graph);
|
||||
|
||||
// tensor type properties are not guaranteed to be correct
|
||||
PeepholeOptimize(graph, /*disable_shape_peepholes*/ true);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user