mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49181 Fuse ClipRangesGatherSigridHash Test Plan: ``` MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 numactl -m 0 -C 3 ./buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench --scripted_model=/data/users/ansha/tmp/adindexer/merge/traced_merge_dper_fixes.pt --pt_inputs=/data/users/ansha/tmp/adindexer/merge/container_precomputation_bs1.pt --iters=30000 --warmup_iters=10000 --num_threads=1 --pred_net=/data/users/ansha/tmp/adindexer/precomputation_merge_net.pb --c2_inputs=/data/users/ansha/tmp/adindexer/merge/c2_inputs_precomputation_bs1.pb --c2_sigrid_transforms_opt=1 --c2_use_memonger=1 --c2_weights=/data/users/ansha/tmp/adindexer/merge/c2_weights_precomputation.pb --pt_enable_static_runtime --pt_cleanup_activations=true --pt_enable_out_variant=true --do_profile --compare_results ``` Verify op fused: Node #3: 0.00104917 ms/iter, %173 : Tensor, %174 : Tensor = fb::clip_ranges_gather_sigrid_hash_offsets(%75, %76, %39, %40, %41, %38, %26) Before: 0.0919786 After: 0.0911792 Reviewed By: hlu1 Differential Revision: D25468225 fbshipit-source-id: 36bd91c140eaa57cb42cdaad46d878b94f162a9d
120 lines
4.4 KiB
C++
120 lines
4.4 KiB
C++
#include <torch/csrc/jit/runtime/static/passes.h>
|
|
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Rewrites the chain
//   aten::cat -> aten::add -> aten::mul -> aten::nan_to_num -> aten::clamp
// into a single fb::concat_add_mul_replacenan_clip node.
//
// Two source patterns are registered; they differ only in whether the NaN
// replacement step is the out-of-place op (aten::nan_to_num) or the in-place
// op (aten::nan_to_num_). Both map onto the same fused replacement.
//
// The fused op is handed only %c (second add operand), %e (second mul
// operand), %a (first cat argument) and %i / %j (the two clamp bounds). The
// remaining matched inputs (%b, %d, %f, %g, %h) are not forwarded, and no
// check is made here that they hold values the fused op is compatible with;
// see the TODO below.
//
// @param graph  Graph to rewrite in place.
void ConcatAddMulReplaceNaNClip(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO:: check restrictions for inputs; outputs not used elsewhere
  const std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %y0 = aten::cat(%a, %b)
        %y1 = aten::add(%y0, %c, %d)
        %y2 = aten::mul(%y1, %e)
        %y3 = aten::nan_to_num(%y2, %f, %g, %h)
        %res = aten::clamp(%y3, %i, %j)
        return (%res))IR";
  const std::string pattern2 = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %y0 = aten::cat(%a, %b)
        %y1 = aten::add(%y0, %c, %d)
        %y2 = aten::mul(%y1, %e)
        %y3 = aten::nan_to_num_(%y2, %f, %g, %h)
        %res = aten::clamp(%y3, %i, %j)
        return (%res))IR";
  const std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %res = fb::concat_add_mul_replacenan_clip(%c, %e, %a, %i, %j)
        return (%res))IR";

  // runOnGraph applies every pattern registered on the rewriter, so both
  // variants are registered up front and the graph is rewritten with a single
  // call. Calling runOnGraph after each registration would re-run the first
  // pattern a second time for no benefit.
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.RegisterRewritePattern(pattern2, fused_pattern);
  fuse.runOnGraph(graph);
}
|
|
|
|
// Collapses an fb::batch_one_hot_lengths call that is sandwiched between two
// aten::to conversions into one fb::casted_batch_one_hot_lengths node.
//
// The match requires both aten::to calls to share the same %c (used for the
// third and fourth argument) and the same %d (fifth argument). The fused op
// receives the unconverted input %a plus %e and %f; the conversion targets
// %b and %g are not forwarded.
//
// @param graph  Graph to rewrite in place.
void CastedBatchOneHotLengths(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO:: check restrictions for inputs; outputs not used elsewhere
  const std::string match_ir = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %y0 : Tensor = aten::to(%a, %b, %c, %c, %d)
        %y1 : Tensor = fb::batch_one_hot_lengths(%y0, %e, %f)
        %res : Tensor = aten::to(%y1, %g, %c, %c, %d)
        return (%res))IR";

  const std::string replacement_ir = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %res : Tensor = fb::casted_batch_one_hot_lengths(%a, %e, %f)
        return (%res))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(match_ir, replacement_ir);
  rewriter.runOnGraph(graph);
}
|
|
|
|
// Replaces the sequence
//   aten::stack -> aten::transpose -> aten::bmm -> aten::flatten
//   -> aten::index_select
// with a single fb::concat_batch_matmul_batch_gather node.
//
// In the matched pattern the same value %b is used as the second argument of
// aten::stack, the second argument of aten::transpose and the second argument
// of aten::index_select, and aten::bmm multiplies the stacked tensor by its
// own transpose. The fused op receives only %f (the last index_select
// argument) and %a (the first stack argument); %b, %c, %d and %e are not
// forwarded.
//
// @param graph  Graph to rewrite in place.
void ConcatBatchMatMulBatchGather(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO:: check restrictions for inputs; outputs not used elsewhere
  const std::string unfused_ir = R"IR(
    graph(%a, %b, %c, %d, %e, %f):
        %y0 : Tensor = aten::stack(%a, %b)
        %y1 : Tensor = aten::transpose(%y0, %b, %c)
        %y2 : Tensor = aten::bmm(%y0, %y1)
        %y3 : Tensor = aten::flatten(%y2, %d, %e)
        %res : Tensor = aten::index_select(%y3, %b, %f)
        return (%res))IR";

  const std::string fused_ir = R"IR(
    graph(%a, %b, %c, %d, %e, %f):
        %res : Tensor = fb::concat_batch_matmul_batch_gather(%f, %a)
        return (%res))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(unfused_ir, fused_ir);
  rewriter.runOnGraph(graph);
}
|
|
|
|
void ClipRangesGatherRangesLengthsToOffsets(
|
|
std::shared_ptr<torch::jit::Graph>& graph) {
|
|
// TODO:: check restrictions for inputs; outputs not used elsewhere
|
|
std::string pattern = R"IR(
|
|
graph(%a, %b, %c, %d):
|
|
%y0 : Tensor = fb::clip_ranges(%b, %c)
|
|
%y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0)
|
|
%y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
|
|
return (%y3, %y1))IR";
|
|
std::string fused_pattern = R"IR(
|
|
graph(%a, %b, %c, %d):
|
|
%y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d)
|
|
return (%y1, %y0))IR";
|
|
SubgraphRewriter fuse;
|
|
fuse.RegisterRewritePattern(pattern, fused_pattern);
|
|
fuse.runOnGraph(graph);
|
|
}
|
|
|
|
// Fuses fb::clip_ranges_gather_lengths_to_offsets followed by fb::sigrid_hash
// (applied to the first output of the former) into one
// fb::clip_ranges_gather_sigrid_hash_offsets node.
//
// The matched subgraph returns (sigrid_hash result, second output of the
// gather op). The replacement returns (%out, %off), i.e. the fused op's
// second output stands in for the hashed values and its first output stands
// in for the gather op's second output. Note the fused op takes its inputs
// in the order (%b, %a, %c, %e, %f, %g, %d), not in pattern-input order.
//
// @param graph  Graph to rewrite in place.
void ClipRangesGatherSigridHash(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO:: check restrictions for inputs; outputs not used elsewhere
  const std::string unfused_ir = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d)
        %y2 : Tensor = fb::sigrid_hash(%y0, %e, %f, %g)
        return (%y2, %y1))IR";

  const std::string fused_ir = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_offsets(%b, %a, %c, %e, %f, %g, %d)
        return (%out, %off))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(unfused_ir, fused_ir);
  rewriter.runOnGraph(graph);
}
|
|
|
|
// Runs every fusion pass defined in this file over `graph`, in a fixed order.
//
// The passes are only compiled in when FBCODE_CAFFE2 is defined; otherwise
// this function leaves the graph untouched.
//
// @param graph  Graph to rewrite in place.
void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
#ifdef FBCODE_CAFFE2
  using FusionPass = void (*)(std::shared_ptr<torch::jit::Graph>&);

  // Order matters: ClipRangesGatherSigridHash matches the
  // fb::clip_ranges_gather_lengths_to_offsets node that
  // ClipRangesGatherRangesLengthsToOffsets emits, so it has to run after it.
  constexpr FusionPass kPasses[] = {
      ConcatAddMulReplaceNaNClip,
      CastedBatchOneHotLengths,
      ConcatBatchMatMulBatchGather,
      ClipRangesGatherRangesLengthsToOffsets,
      ClipRangesGatherSigridHash,
  };

  for (const FusionPass pass : kPasses) {
    pass(graph);
  }
#endif
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|