#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>

namespace torch::jit {

namespace {

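// Replaces aten::batch_norm with a pair of Vulkan prepack ops: a
// create_batchnorm_context op that packs the affine and statistics tensors,
// and a run_batchnorm_context op that consumes the packed context at runtime.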
void insertPrePackedBatchNormOp(std::shared_ptr<Graph>& graph) {
  std::string batchnorm_pattern = R"(
    graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
        %r = aten::batch_norm(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
        %op_context : __torch__.torch.classes.vulkan.BatchNormPackedContext = vulkan_prepack::create_batchnorm_context(
            %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
        %res = vulkan_prepack::run_batchnorm_context(%input, %op_context)
        return (%res))";

  SubgraphRewriter batchnorm_rewriter;
  batchnorm_rewriter.RegisterRewritePattern(
      batchnorm_pattern, prepacked_ops_pattern);
  batchnorm_rewriter.runOnGraph(graph);
}

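// Fuses decomposed linear patterns into aten::linear, then replaces
// aten::linear with Vulkan create_linear_context / run_linear_context ops.
// The weight is transposed before being handed to the packing op.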
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
  // fuse decomposed linear into aten::linear
  FuseLinear(graph);

  std::string linear_pattern = R"(
    graph(%input, %weight, %bias):
        %r = aten::linear(%input, %weight, %bias)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias):
        %weight_t = aten::t(%weight)
        %packed_weight_bias = vulkan_prepack::create_linear_context(
            %weight_t, %bias)
        %res = vulkan_prepack::run_linear_context(%input, %packed_weight_bias)
        return (%res))";

  SubgraphRewriter linear_rewriter;
  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
  linear_rewriter.runOnGraph(graph);
}

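// Replaces aten::layer_norm with Vulkan create_layernorm_context /
// run_layernorm_context ops; %normalized_shape remains a runtime argument.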
void insertPrePackedLayernormOp(std::shared_ptr<Graph>& graph) {
  std::string layernorm_pattern = R"(
    graph(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable):
        %r = aten::layer_norm(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable):
        %op_context : __torch__.torch.classes.vulkan.LayernormPackedContext = vulkan_prepack::create_layernorm_context(
            %weight, %bias, %eps)
        %res = vulkan_prepack::run_layernorm_context(%input, %normalized_shape, %op_context)
        return (%res))";

  SubgraphRewriter layernorm_rewriter;
  layernorm_rewriter.RegisterRewritePattern(
      layernorm_pattern, prepacked_ops_pattern);
  layernorm_rewriter.runOnGraph(graph);
}

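// Normalizes aten::_convolution into aten::conv2d / aten::conv_transpose2d,
// then rewrites both into their Vulkan prepacked create/run counterparts.
// The prim::Constant() None values stand in for the optional output min/max
// clamp bounds; the clamp-fusion passes below fill them in when possible.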
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_2d_pattern, prepacked_ops_conv2d_pattern);
  rewriter.runOnGraph(graph);

  std::string conv_2d_transpose_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
          %output_padding:int[], %groups:int):
        %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
        return (%res) )";

  std::string prepacked_ops_conv2d_transpose_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = vulkan_prepack::create_tconv2d_context(
            %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %res = vulkan_prepack::run_tconv2d_context(%input, %packed_weight_bias)
        return (%res) )";

  SubgraphRewriter transpose_rewriter;
  transpose_rewriter.RegisterRewritePattern(
      conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern);
  transpose_rewriter.runOnGraph(graph);
}

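// Same treatment for aten::conv1d: weight, bias, and conv parameters are
// packed into a Vulkan conv1d context consumed by run_conv1d_context.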
void insertPrePackedConv1dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_1d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv1d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %packed_weight_bias = vulkan_prepack::create_conv1d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups)
        %r = vulkan_prepack::run_conv1d_context(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_1d_pattern, prepacked_ops_conv1d_pattern);
  rewriter.runOnGraph(graph);
}

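// Inserts backend transfers at the graph boundary: every used tensor input is
// routed through aten::to("vulkan") right before its first use, and every
// tensor output is routed through aten::to("cpu") so callers keep receiving
// CPU tensors.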
void transferInputOutputBackends(std::shared_ptr<Graph>& graph) {
  // Move inputs to Vulkan backend
  for (Value* input : graph->inputs()) {
    NamedValue named_input = NamedValue("", input);
    if (named_input.type()->kind() == TypeKind::TensorType &&
        !input->uses().empty()) {
      // find the insertion point
      WithInsertPoint ip(input->uses()[0].user->prev());
      Value* replaced_input = graph->insert(
          Symbol::fromQualString("aten::to"), {named_input, "vulkan"});
      // replace the input
      input->replaceAllUsesAfterNodeWith(
          replaced_input->node(), replaced_input);
    }
  }

  // Move outputs to CPU backend
  at::ArrayRef<Value*>&& outputs = graph->outputs();
  for (size_t i = 0; i < outputs.size(); i++) {
    Value* output = outputs[i];
    NamedValue named_output = NamedValue("", output);
    if (named_output.type()->kind() == TypeKind::TensorType) {
      // find the insertion point
      WithInsertPoint ip(output->node()->next());
      Value* replaced_output = graph->insert(
          Symbol::fromQualString("aten::to"), {named_output, "cpu"});
      // replace the output
      graph->block()->replaceOutput(i, replaced_output);
    }
  }

  SubgraphRewriter rewriter;
  rewriter.runOnGraph(graph);
}

void transferInputOutputBackends(script::Module& module) {
  std::shared_ptr<Graph> graph = module.get_methods()[0].graph();
  transferInputOutputBackends(graph);
}

void eliminateDeadCode(script::Module& module) {
  for (auto& method : module.get_methods()) {
    EliminateDeadCode(method.graph());
  }
}

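// Maps CPU quantized ops onto their Vulkan equivalents:
//   quantized::add / quantized::mul       -> vulkan_quantized::add / mul
//   quantized::conv2d / conv_transpose2d  -> convert_q(t)conv2d_context + run_qconv2d_context
//   quantized::conv2d_relu                -> same, with output_min pinned to 0.0
//   quantized::linear                     -> convert_linear_context + run_qlinear_context
// The convert_* ops take the existing quantized packed params and produce
// Vulkan packed contexts.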
void rewriteQuantizedOps(std::shared_ptr<Graph>& graph) {
  // quantized::add
  std::string quantized_add_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
        %res = quantized::add(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_add_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
        %res = vulkan_quantized::add(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_add_rewriter;
  quantized_add_rewriter.RegisterRewritePattern(
      quantized_add_pattern, vk_quantized_add_pattern);
  quantized_add_rewriter.runOnGraph(graph);

  // quantized::mul
  std::string quantized_mul_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
        %res = quantized::mul(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_mul_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
        %res = vulkan_quantized::mul(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_mul_rewriter;
  quantized_mul_rewriter.RegisterRewritePattern(
      quantized_mul_pattern, vk_quantized_mul_pattern);
  quantized_mul_rewriter.runOnGraph(graph);

  // quantized::conv2d
  std::string quantized_conv2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
        %res = quantized::conv2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_conv2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %output_min_max : None = prim::Constant()
        %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qconv2d_context(
            %packed_params, %output_min_max, %output_min_max)
        %res = vulkan_prepack::run_qconv2d_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_conv2d_rewriter;
  quantized_conv2d_rewriter.RegisterRewritePattern(
      quantized_conv2d_pattern, vk_quantized_conv2d_pattern);
  quantized_conv2d_rewriter.runOnGraph(graph);

  // quantized::conv_transpose2d
  std::string quantized_conv_transpose2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
        %res = quantized::conv_transpose2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_conv_transpose2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %output_min_max : None = prim::Constant()
        %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qtconv2d_context(
            %packed_params, %output_min_max, %output_min_max)
        %res = vulkan_prepack::run_qconv2d_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_conv_transpose2d_rewriter;
  quantized_conv_transpose2d_rewriter.RegisterRewritePattern(
      quantized_conv_transpose2d_pattern,
      vk_quantized_conv_transpose2d_pattern);
  quantized_conv_transpose2d_rewriter.runOnGraph(graph);

  // quantized::conv2d_relu
  std::string quantized_conv2d_relu_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
        %res = quantized::conv2d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_conv2d_relu_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qconv2d_context(
            %packed_params, %output_min, %output_max)
        %res = vulkan_prepack::run_qconv2d_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_conv2d_relu_rewriter;
  quantized_conv2d_relu_rewriter.RegisterRewritePattern(
      quantized_conv2d_relu_pattern, vk_quantized_conv2d_relu_pattern);
  quantized_conv2d_relu_rewriter.runOnGraph(graph);

  // quantized::linear
  std::string quantized_linear_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
        %res = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_linear_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %vk_packed_params : __torch__.torch.classes.vulkan.LinearPackedContext = vulkan_quantized_prepack::convert_linear_context(
            %packed_params)
        %res = vulkan_prepack::run_qlinear_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_linear_rewriter;
  quantized_linear_rewriter.RegisterRewritePattern(
      quantized_linear_pattern, vk_quantized_linear_pattern);
  quantized_linear_rewriter.runOnGraph(graph);
}

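// Rewrites aten::gru into create_gru_context / run_gru_context. The filter
// only admits matches whose %params_cpu value is statically typed Tensor[].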
void insertPrePackedGruOp(std::shared_ptr<Graph>& graph) {
  std::string gru_pattern = R"(
    graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %y.1 : Tensor, %hn.1 : Tensor = aten::gru(%input.1, %hx.1, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        return (%y.1, %hn.1) )";
  std::string prepacked_ops_pattern = R"(
    graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %packed_weights_biases = vulkan_prepack::create_gru_context(
            %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        %y.1 : Tensor, %hn.1 : Tensor = vulkan_prepack::run_gru_context(%input.1, %hx.1, %packed_weights_biases)
        return (%y.1, %hn.1) )";

  auto filter = [&](const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    auto node = match.values_map.at(vmap.at("params_cpu"))->node();
    return node->output()->type()->str() == "Tensor[]";
  };

  SubgraphRewriter gru_rewriter;
  gru_rewriter.RegisterRewritePattern(gru_pattern, prepacked_ops_pattern);
  gru_rewriter.runOnGraph(graph, filter);
}

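// Rewrites aten::lstm the same way; the hidden-state list is unpacked with
// prim::ListUnpack so run_lstm_context receives the two states separately.
// The filter requires %hx to be statically typed Tensor[].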
void insertPrePackedLstmOp(std::shared_ptr<Graph>& graph) {
  std::string lstm_pattern = R"(
    graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = aten::lstm(%input.1, %hx, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        return (%y.1, %hn.1, %cn.1) )";
  std::string prepacked_ops_pattern = R"(
    graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %packed_weights_biases = vulkan_prepack::create_lstm_context(
            %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        %hx.1 : Tensor, %cx.1 : Tensor = prim::ListUnpack(%hx)
        %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = vulkan_prepack::run_lstm_context(%input.1, %hx.1, %cx.1, %packed_weights_biases)
        return (%y.1, %hn.1, %cn.1) )";

  auto filter = [&](const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    auto node = match.values_map.at(vmap.at("hx"))->node();
    return node->output()->type()->str() == "Tensor[]";
  };

  SubgraphRewriter lstm_rewriter;
  lstm_rewriter.RegisterRewritePattern(lstm_pattern, prepacked_ops_pattern);
  lstm_rewriter.runOnGraph(graph, filter);
}

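// Folds aten::hardtanh / aten::hardtanh_ that directly follow a prepacked
// Vulkan conv2d into the packed context itself: the clamp bounds become the
// context's output_min/output_max and the separate activation node goes away.
// Gated by graph_rewrite_helper::isClampFusable.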
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string conv2d_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

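// Same idea for aten::relu / aten::relu_: the fused replacement pins
// output_min to 0.0 and leaves output_max unset (None) in the packed
// conv2d context.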
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string conv2d_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::relu(%conv2d_res)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);

  std::string conv2d_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::relu_(%conv2d_res)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

} // namespace

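// Runs all of the prepack-insertion rewrites above on a single graph.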
void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedLinearOp(graph);
  insertPrePackedLayernormOp(graph);
  insertPrePackedConv2dOp(graph);
  insertPrePackedConv1dOp(graph);
  rewriteQuantizedOps(graph);
  insertPrePackedGruOp(graph);
  insertPrePackedLstmOp(graph);
  insertPrePackedBatchNormOp(graph);
}

void vulkanInsertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    vulkanInsertPrePackedOps(graph);
  }
  for (script::Module m : module.children()) {
    vulkanInsertPrePackedOps(m);
  }
}

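// Fuses relu/hardtanh clamps into the prepacked conv contexts of the
// module's forward graph.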
void vulkanFusePrePackedConvWithClamp(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  fuseReluWithPackedOps(graph);
  fuseHardtanhWithPackedOps(graph);
}

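// Folds the vulkan_prepack::create_* / vulkan_quantized_prepack::convert_*
// ops listed below via PrePackingOpsFolder, so the packed contexts are
// computed once and stored on the module rather than being recreated on
// every run.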
void vulkanFoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
    return (
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_conv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_tconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_qconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_qtconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString(
             "vulkan_quantized_prepack::convert_qconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_conv1d_context")) ||
        (n->kind() ==
         Symbol::fromQualString(
             "vulkan_quantized_prepack::convert_qtconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString(
             "vulkan_quantized_prepack::convert_linear_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_linear_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_layernorm_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_gru_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_lstm_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_batchnorm_context")));
  };
  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}

static void vulkanRemoveMutation(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  RemoveTensorMutation(graph);
}

static void vulkanRunCanonicalOptimizations(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  for (const auto& method : module.get_methods()) {
    auto method_graph = method.graph();
    runOptimization(method_graph, false /* no loop unrolling */);
  }
}

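// Entry point of the Vulkan mobile optimization pipeline. Works on a cloned,
// eval-mode, frozen copy of the module: folds Conv+BatchNorm, inserts and
// folds the Vulkan prepacked ops, fuses clamps, removes dropout and tensor
// mutation, optionally inserts CPU<->Vulkan backend transfers (unless
// VULKAN_AUTOMATIC_GPU_TRANSFER is blocklisted), runs canonical graph
// optimizations, and tags the result with "optimized_for_vulkan".
//
// Minimal usage sketch (illustrative only; "model.pt" and "model_vulkan.ptl"
// are placeholder paths, not part of this file):
//
//   script::Module m = torch::jit::load("model.pt");
//   auto optimized = vulkanOptimizeForMobile(
//       m, /*optimization_blocklist=*/{}, /*preserved_methods=*/{});
//   optimized._save_for_mobile("model_vulkan.ptl");
//
// From Python this path is typically reached through
// torch.utils.mobile_optimizer.optimize_for_mobile with the Vulkan backend.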
script::Module vulkanOptimizeForMobile(
    const script::Module& m,
    const std::set<MobileOptimizerType>& optimization_blocklist,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();
  cloned_module = FoldConvBatchNorm(cloned_module);
  cloned_module = freeze_module(cloned_module, preserved_methods);
  vulkanInsertPrePackedOps(cloned_module);
  vulkanFusePrePackedConvWithClamp(cloned_module);
  vulkanFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  vulkanRemoveMutation(cloned_module);

  if (!optimization_blocklist.count(
          MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)) {
    transferInputOutputBackends(cloned_module);
    cloned_module.register_attribute(
        "requires_backend_transfers", BoolType::get(), false);
  }

  // remove duplicated constants
  vulkanRunCanonicalOptimizations(cloned_module);
  eliminateDeadCode(cloned_module);

  cloned_module.register_attribute(
      "optimized_for_vulkan", BoolType::get(), true);
  return cloned_module;
}

} // namespace torch::jit