Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
[SR] Fuse quantized linear/relu (#75775)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75775

fbgemm kernels already implement the fused kernel, no reason not to use it.

ghstack-source-id: 155450342

Test Plan: New unit tests

Reviewed By: navahgar

Differential Revision: D35633297

fbshipit-source-id: a744a33a65ce7dbb9ce8900dbe091b6d56dd4e48
(cherry picked from commit b1361b349862715aa17e6318c5e658cd6401a464)
This commit is contained in:
parent 7dc1383101
commit fc64dbdc01
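For context (not part of this commit): below is a minimal standalone sketch of the SubgraphRewriter-based rewrite that this change wires into Static Runtime. It assumes a libtorch build in which the fbgemm quantized ops (including the fused quantized::linear_relu_dynamic_fp16 kernel the summary refers to) are registered; the IR strings mirror the patterns added in the diff below.

// Minimal sketch, assuming a libtorch build with fbgemm quantized ops registered.
// It applies the same linear/relu fusion pattern as a standalone program.
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include <memory>

int main() {
  auto graph = std::make_shared<torch::jit::Graph>();
  // The unfused shape of IR: dynamic fp16 linear followed by relu.
  torch::jit::parseIR(R"IR(
    graph(%input: Tensor, %weights: Tensor):
        %bias: None = prim::Constant()
        %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
        %x = quantized::linear_dynamic_fp16(%input, %packed_params)
        %y = aten::relu(%x)
        return (%y))IR",
      graph.get());

  // Register the same pattern -> fused-pattern rewrite used by the new pass.
  torch::jit::SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(
      R"IR(
    graph(%input, %packed_params):
        %x : Tensor = quantized::linear_dynamic_fp16(%input, %packed_params)
        %y : Tensor = aten::relu(%x)
        return (%y))IR",
      R"IR(
    graph(%input, %packed_params):
        %x : Tensor = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
        return (%x))IR");
  fuse.runOnGraph(graph);

  // After the rewrite, the linear/relu pair should show up as a single
  // quantized::linear_relu_dynamic_fp16 node.
  graph->dump();
  return 0;
}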
@@ -3274,3 +3274,27 @@ TEST(StaticRuntime, NestedBlockIfReturnList) {
       at::randn({42, 42}), at::randn({42, 42}), true, false};
   testStaticRuntime(src, args1, args2);
 }
+
+TEST(StaticRuntime, QuantizedLinearDynamicFp16ReluFusion) {
+  const auto src = R"IR(
+    graph(%input: Tensor, %weights: Tensor):
+        %bias: None = prim::Constant()
+        %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
+        %x = quantized::linear_dynamic_fp16(%input, %packed_params)
+        %y = aten::relu(%x)
+        %ret = aten::clone(%y, %bias)
+        return (%ret)
+  )IR";
+  at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
+  at::Tensor input = torch::randn({3, 2}, torch::kFloat);
+
+  at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
+  at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
+
+  testStaticRuntime(src, {input, weight}, {input_2, weight_2});
+
+  auto graph = getGraphFromIR(src);
+  QuantizedLinearReluFusion(graph);
+  EXPECT_FALSE(hasNodeWithKind(graph, "quantized::linear_dynamic_fp16"));
+  EXPECT_TRUE(hasNodeWithKind(graph, "quantized::linear_relu_dynamic_fp16"));
+}
@@ -177,6 +177,7 @@ void OptimizeGraph(
       graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
   AddIfThenElseOp(graph);
   UseSplitAndSqueeze(graph);
+  QuantizedLinearReluFusion(graph);
   GRAPH_DUMP("Final graph after optimizations: ", graph);
 }
 
@@ -1229,5 +1229,20 @@ void EliminateNoOpSlice(std::shared_ptr<Graph>& graph) {
   }
 }
 
+void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph) {
+  std::string pattern = R"IR(
+    graph(%input, %packed_params):
+        %x : Tensor = quantized::linear_dynamic_fp16(%input, %packed_params)
+        %y : Tensor = aten::relu(%x)
+        return (%y))IR";
+  std::string fused_pattern = R"IR(
+    graph(%input, %packed_params):
+        %x : Tensor = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
+        return (%x))IR";
+  SubgraphRewriter fuse;
+  fuse.RegisterRewritePattern(pattern, fused_pattern);
+  fuse.runOnGraph(graph);
+}
+
 } // namespace jit
 } // namespace torch
@@ -76,5 +76,7 @@ TORCH_API void RemoveUnnecessaryOutputs(std::shared_ptr<Graph>& graph);
 TORCH_API void RemoveUnnecessaryEmbeddingBagOutputs(
     std::shared_ptr<Graph>& graph);
 
+TORCH_API void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph);
+
 } // namespace jit
 } // namespace torch