mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
lower batchmm to non-diff optimization (#19987)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19987 ghimport-source-id: ca4c38312bd56d8a71f1925297deee7f64f573d3 Differential Revision: D15190356 Pulled By: wanchaol fbshipit-source-id: 761edb08c670fcbc24a06a5b11ceddf311f75884
This commit is contained in:
parent
0c5dc965a4
commit
8fbde94664
|
|
@ -638,16 +638,17 @@ struct GraphExecutorImpl {
|
|||
UnrollLoops(graph);
|
||||
EliminateCommonSubexpression(graph);
|
||||
|
||||
// Rewrite subgraphs with many MMs into expressions that batch them.
|
||||
BatchMM(graph);
|
||||
|
||||
CheckInplace(graph);
|
||||
}
|
||||
|
||||
void runNondiffOptimization(std::shared_ptr<Graph>& graph) {
|
||||
// run custom passes that different backends can register
|
||||
for (const auto& pass : getCustomPasses()) {
|
||||
pass(graph);
|
||||
}
|
||||
// Rewrite subgraphs with many MMs into expressions that batch them.
|
||||
BatchMM(graph);
|
||||
|
||||
FuseGraph(graph);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user