mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[StaticRuntime] Support a new pattern (aten::to with 5 inputs) for ClipRangesToGatherToOffsets (#147189)
Summary: Support the following new pattern for ClipRangesToGatherToOffsets: Before optimization: ``` %11175 : Tensor, %11176 : Tensor = fb::clip_ranges_gather(%int_66.1, %getitem_1784.1, %347) %getattr_256.1 : int = prim::dtype(%11175) %to_298.1 : Tensor = aten::to(%11176, %getattr_256.1, %13, %13, %12) %lengths_to_offsets_333.1 : Tensor = fb::lengths_to_offsets(%to_298.1, %8) ``` After optimization: ``` %11199 : int = prim::dtype(%int_66.1) %11200 : Tensor, %11201 : Tensor = fb::clip_ranges_gather_to_offsets(%int_66.1, %getitem_1784.1, %347, %8, %11199) ``` It is similar with https://github.com/pytorch/pytorch/pull/146931, but aten::to has 5 inputs instead of 4. Differential Revision: D69627793 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147189 Approved by: https://github.com/hanyilou123
This commit is contained in:
parent
5c0c99f658
commit
a8fa4bcfd2
|
|
@ -260,19 +260,34 @@ namespace {
|
|||
[[maybe_unused]] void ClipRangesToGatherToOffsetsV2(
|
||||
std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
std::string pattern = R"IR(
|
||||
graph(%a, %b, %c, %d, %to0_in0, %to0_in1):
|
||||
%y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather(%a, %b, %c)
|
||||
%y0_type : int = prim::dtype(%y0)
|
||||
%y2 : Tensor = aten::to(%y1, %y0_type, %to0_in0, %to0_in0, %to0_in1)
|
||||
%y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
|
||||
return (%y3, %y0))IR";
|
||||
std::string fused_pattern = R"IR(
|
||||
graph(%a, %b, %c, %d, %to0_in0, %to0_in1):
|
||||
%a_type : int = prim::dtype(%a)
|
||||
%y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_to_offsets(%a, %b, %c, %d, %a_type)
|
||||
return (%y1, %y0))IR";
|
||||
SubgraphRewriter fuse;
|
||||
fuse.RegisterRewritePattern(pattern, fused_pattern);
|
||||
fuse.runOnGraph(graph);
|
||||
|
||||
std::string pattern2 = R"IR(
|
||||
graph(%a, %b, %c, %d, %to0_in0):
|
||||
%y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather(%a, %b, %c)
|
||||
%y0_type : int = prim::dtype(%y0)
|
||||
%y2 : Tensor = aten::to(%y1, %y0_type, %to0_in0, %to0_in0)
|
||||
%y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
|
||||
return (%y3, %y0))IR";
|
||||
std::string fused_pattern = R"IR(
|
||||
std::string fused_pattern2 = R"IR(
|
||||
graph(%a, %b, %c, %d, %to0_in0):
|
||||
%a_type : int = prim::dtype(%a)
|
||||
%y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_to_offsets(%a, %b, %c, %d, %a_type)
|
||||
return (%y1, %y0))IR";
|
||||
SubgraphRewriter fuse;
|
||||
fuse.RegisterRewritePattern(pattern, fused_pattern);
|
||||
fuse.RegisterRewritePattern(pattern2, fused_pattern2);
|
||||
fuse.runOnGraph(graph);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user