pytorch/torch/csrc/jit/runtime/static/passes.cpp
Ansha Yu 07978bd62e [static runtime] fuse inference ops (1) (#48948)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48948

Fuse inference ops for the following inside static runtime:
ConcatAddMulReplaceNaNClip
CastedBatchOneHotLengths
ConcatBatchMatMulBatchGather

TODO:
1. add unit tests
2. add more restrictions on the graph transform (e.g. check inputs, check outputs not used elsewhere)

Test Plan:
Run adindexer model with static runtime and fusion; check ops
```
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 numactl -m 0 -C 3 ./buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench --scripted_model=/data/users/ansha/tmp/adindexer/traced_precomputation2.pt --pt_inputs=/data/users/ansha/tmp/adindexer/merge/container_precomputation_bs1.pt --iters=3000 --warmup_iters=10000  --num_threads=1 --pred_net=/data/users/ansha/tmp/adindexer/precomputation_merge_net.pb --c2_inputs=/data/users/ansha/tmp/adindexer/merge/c2_inputs_precomputation_bs1.pb --c2_sigrid_transforms_opt=1 --c2_use_memonger=1 --c2_weights=/data/users/ansha/tmp/adindexer/merge/c2_weights_precomputation.pb --pt_enable_static_runtime
```
transformed model graph contains the fused ops: P151559641

Results before fusion: P151567611
Results after fusion: P151566783 (8% speedup for bs=20, 14% speedup for bs=1)

Reviewed By: hlu1

Differential Revision: D25224107

fbshipit-source-id: c8442e8ceb018879c61ce564367b1c1b9412601b
2020-12-08 05:54:49 -08:00

84 lines
2.9 KiB
C++

#include <torch/csrc/jit/runtime/static/passes.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
namespace torch {
namespace jit {
// Collapses the chain
//   aten::cat -> aten::add -> aten::mul -> aten::nan_to_num -> aten::clamp
// into a single fb::concat_add_mul_replacenan_clip node. The in-place
// spelling aten::nan_to_num_ is matched as well. `graph` is rewritten in
// place.
//
// NOTE(review): the fused node only receives %c, %e, %a, %i and %j. The second
// aten::cat argument (%b), the third aten::add argument (%d) and the extra
// nan_to_num arguments (%f, %g, %h) are dropped by the replacement --
// presumably the fused kernel fixes those values; confirm against the op.
void ConcatAddMulReplaceNaNClip(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  const std::string pattern = R"IR(
graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
%y0 = aten::cat(%a, %b)
%y1 = aten::add(%y0, %c, %d)
%y2 = aten::mul(%y1, %e)
%y3 = aten::nan_to_num(%y2, %f, %g, %h)
%res = aten::clamp(%y3, %i, %j)
return (%res))IR";
  // Same chain, but with the in-place nan_to_num_ variant.
  const std::string pattern2 = R"IR(
graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
%y0 = aten::cat(%a, %b)
%y1 = aten::add(%y0, %c, %d)
%y2 = aten::mul(%y1, %e)
%y3 = aten::nan_to_num_(%y2, %f, %g, %h)
%res = aten::clamp(%y3, %i, %j)
return (%res))IR";
  const std::string fused_pattern = R"IR(
graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
%res = fb::concat_add_mul_replacenan_clip(%c, %e, %a, %i, %j)
return (%res))IR";

  // The rewriter keeps every registered pattern and applies them in
  // registration order on each run. Registering both variants before a single
  // runOnGraph call therefore rewrites `pattern` first and `pattern2` second,
  // without scanning the graph for `pattern` a second time.
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.RegisterRewritePattern(pattern2, fused_pattern);
  fuse.runOnGraph(graph);
}
// Replaces the sequence
//   aten::to -> fb::batch_one_hot_lengths -> aten::to
// with a single fb::casted_batch_one_hot_lengths node. `graph` is rewritten
// in place.
//
// Both aten::to calls in the pattern share %c (used twice per call) and %d,
// so the pattern is written to match only when the same values feed both
// casts. The fused node receives just %a, %e and %f; the cast arguments
// (%b, %c, %d, %g) do not appear in the replacement.
void CastedBatchOneHotLengths(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  SubgraphRewriter rewriter;

  // Unfused form: cast, one-hot, cast back.
  const std::string unfused_ir = R"IR(
graph(%a, %b, %c, %d, %e, %f, %g):
%y0 : Tensor = aten::to(%a, %b, %c, %c, %d)
%y1 : Tensor = fb::batch_one_hot_lengths(%y0, %e, %f)
%res : Tensor = aten::to(%y1, %g, %c, %c, %d)
return (%res))IR";

  // Fused form: one op taking the original (un-cast) input.
  const std::string fused_ir = R"IR(
graph(%a, %b, %c, %d, %e, %f, %g):
%res : Tensor = fb::casted_batch_one_hot_lengths(%a, %e, %f)
return (%res))IR";

  rewriter.RegisterRewritePattern(unfused_ir, fused_ir);
  rewriter.runOnGraph(graph);
}
// Replaces the sequence
//   aten::stack -> aten::transpose -> aten::bmm -> aten::flatten
//   -> aten::index_select
// with a single fb::concat_batch_matmul_batch_gather node. `graph` is
// rewritten in place.
//
// In the pattern, aten::bmm multiplies the stacked value by its own
// transpose, and %b is shared by stack, transpose and index_select. The
// fused node receives only %f and %a (in that order); %b, %c, %d and %e do
// not appear in the replacement.
void ConcatBatchMatMulBatchGather(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  SubgraphRewriter rewriter;

  // Unfused form.
  const std::string unfused_ir = R"IR(
graph(%a, %b, %c, %d, %e, %f):
%y0 : Tensor = aten::stack(%a, %b)
%y1 : Tensor = aten::transpose(%y0, %b, %c)
%y2 : Tensor = aten::bmm(%y0, %y1)
%y3 : Tensor = aten::flatten(%y2, %d, %e)
%res : Tensor = aten::index_select(%y3, %b, %f)
return (%res))IR";

  // Fused form.
  const std::string fused_ir = R"IR(
graph(%a, %b, %c, %d, %e, %f):
%res : Tensor = fb::concat_batch_matmul_batch_gather(%f, %a)
return (%res))IR";

  rewriter.RegisterRewritePattern(unfused_ir, fused_ir);
  rewriter.runOnGraph(graph);
}
// Runs the three fusion passes defined in this file on `graph`, in this
// order: ConcatAddMulReplaceNaNClip, CastedBatchOneHotLengths,
// ConcatBatchMatMulBatchGather.
//
// The passes are compiled in only when FBCODE_CAFFE2 is defined; in any other
// build this function leaves `graph` untouched.
void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
#if defined(FBCODE_CAFFE2)
  ConcatAddMulReplaceNaNClip(graph);
  CastedBatchOneHotLengths(graph);
  ConcatBatchMatMulBatchGather(graph);
#endif // defined(FBCODE_CAFFE2)
}
} // namespace jit
} // namespace torch