#include <ATen/Config.h>
#include <ATen/code_template.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/mkldnn_rewrite.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
namespace torch::jit {

#if AT_MKLDNN_ENABLED()
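
// Returns the sizes recorded in the TensorType of the idx-th input of node n.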
static c10::VaryingShape<int64_t> getSizesOf(Node* n, size_t idx) {
  auto tt = n->input(idx)->type()->cast<TensorType>();
  return tt->sizes();
}
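
// Rewrites a single aten::conv2d node into a pair of
// mkldnn_prepacked::conv2d_prepack / mkldnn_prepacked::conv2d_run nodes.
// The rewrite only applies when both the input and the weight are
// ChannelsLast contiguous and the conv is not one that NNC already handles
// (e.g. depthwise conv2d).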
static void insertPrePackedConvOpForNode(Node* n) {
  constexpr int POS_INPUT = 0;
  constexpr int POS_WEIGHT = 1;
  if (!tensorexpr::isContiguous(
          n->input(POS_INPUT), at::MemoryFormat::ChannelsLast)) {
    GRAPH_DEBUG(
        "insertPrePackedConvOpForNode: input is not ChannelsLast contiguous");
    return;
  }

  if (!tensorexpr::isContiguous(
          n->input(POS_WEIGHT), at::MemoryFormat::ChannelsLast)) {
    GRAPH_DEBUG(
        "insertPrePackedConvOpForNode: weight is not ChannelsLast contiguous");
    return;
  }

  // Leave depthwise conv2d to NNC
  if (tensorexpr::conv2dIsSupportedJit(n)) {
    GRAPH_DEBUG("insertPrePackedConvOpForNode: leave depthwise conv2d to NNC");
    return;
  }

  WithInsertPoint guard(n);
  auto graph = n->owningGraph();

  auto input_sizes = getSizesOf(n, POS_INPUT);
  IValue input_size_value(*input_sizes.concrete_sizes());
  auto input_size = graph->insertConstant(input_size_value);

  auto prepack_node = graph->create(
      Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"), 1);

  // skip input value
  for (const auto i : c10::irange(1, n->inputs().size())) {
    Value* v = n->input(i);
    prepack_node->addInput(v);
  }
  prepack_node->addInput(input_size);
  auto attr = graph->insertConstant(IValue("none"));
  prepack_node->addInput(attr);
  prepack_node->output()->setType(
      getCustomClass("__torch__.torch.classes.mkldnn.ConvOpContext"));
  graph->insertNode(prepack_node);

  auto prepack_conv = graph->insertNode(
      graph->create(Symbol::fromQualString("mkldnn_prepacked::conv2d_run"), 1));
  prepack_conv->addInput(n->input(0));
  prepack_conv->addInput(prepack_node->output());
  prepack_conv->output()->setType(n->output()->type()->cast<TensorType>());

  n->output()->replaceAllUsesWith(prepack_conv->output());
}
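
// Returns true only if every Tensor-typed input of the node has a known
// device and that device is CPU.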
static bool isTensorTypeCPU(Node* node) {
  for (const auto& input : node->inputs()) {
    auto type = input->type()->cast<TensorType>();
    if (!type) {
      continue;
    }
    auto device = type->device();
    if (!device) {
      return false;
    }
    if (!device->is_cpu()) {
      return false;
    }
  }
  return true;
}
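
// Recursively walks a block and rewrites every CPU aten::conv2d into the
// prepacked MKLDNN form, then removes any nodes left dead by the rewrite.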
static void insertPrePackedConvOp(Block* b) {
  for (Node* n : b->nodes()) {
    for (Block* b : n->blocks()) {
      insertPrePackedConvOp(b);
    }

    if (n->kind() == aten::conv2d) {
      if (isTensorTypeCPU(n)) {
        insertPrePackedConvOpForNode(n);
      }
    }
  }
  EliminateDeadCode(b);
}
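
// Graph-level wrappers around the block-level rewrite above.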
static void insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  insertPrePackedConvOp(graph->block());
}

static void insertMkldnnPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertMkldnnPrePackedConv2dOp(graph);
}
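
// For each eltwise op registered in mkldnn::fusion_rewrite_map, rewrites the
// pattern conv2d_prepack + conv2d_run + aten::<op> into a single
// conv2d_prepack + conv2d_run pair that passes the op name as the attr
// argument, so the eltwise op is absorbed into the packed convolution.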
static void FuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  auto conv_op_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
        %packed_weight_bias = mkldnn_prepacked::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %input_size, %dummy_attr)
        %conv2d_res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias)
        %res = aten::${op}(%conv2d_res)
        return (%res))");

  auto conv_op_fused_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
        %attr: str = prim::Constant[value="${op_attr}"]()
        %packed_weight_bias : __torch__.torch.classes.mkldnn.ConvOpContext = mkldnn_prepacked::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %input_size, %attr)
        %res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias)
        return (%res))");

  for (auto const& it : mkldnn::fusion_rewrite_map) {
    std::string op = it.first;
    if (op == std::string("none")) {
      continue;
    }

    at::jit::TemplateEnv env;
    env.s("op", op);

    at::jit::TemplateEnv env_fused;
    env_fused.s("op_attr", op);

    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(
        conv_op_rstring.format(env), conv_op_fused_rstring.format(env_fused));

    auto filters = it.second;
    rewriter.runOnGraph(graph, filters);
  }
}
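
// Constant-folds mkldnn_prepacked::conv2d_prepack nodes whose inputs are all
// constants: the prepack node is executed once here and its packed-weight
// result is re-inserted into the graph as a constant.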
static void PrePackingOpsFolder(Block* b) {
  auto is_foldable_op = [](const Node* n) -> bool {
    return (
        n->kind() ==
        Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"));
  };

  std::unordered_set<Node*> nodes_to_delete;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      PrePackingOpsFolder(block);
    }
    if (is_foldable_op(n)) {
      auto optional_outputs = torch::jit::runNodeIfInputsAreConstant(n);
      if (optional_outputs) {
        auto outputs = optional_outputs.value();
        TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output");
        Value* prepack_op_value = n->output(0);
        auto graph = n->owningGraph();
        WithInsertPoint ins(prepack_op_value->node());
        auto weak_class_obj =
            outputs[0].toObject()->copy_to_weak_compilation_ref();
        Value* packed_weight = graph->insertConstant(weak_class_obj)
                                   ->setType(n->output(0)->type());
        prepack_op_value->replaceAllUsesWith(packed_weight);
        nodes_to_delete.insert(n);
      }
    }
  }
  for (auto n : nodes_to_delete) {
    n->removeAllInputs();
  }
  for (auto n : nodes_to_delete) {
    n->destroy();
  }
}

static void FoldPrePackingOps(std::shared_ptr<Graph>& graph) {
  PrePackingOpsFolder(graph->block());
}
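
// Entry point of the pass: inserts prepacked conv ops, fuses supported
// eltwise ops into them, and folds the prepack calls into constants.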
void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before insertMkldnnPrePackedOps. Beginning of FuseConvWithEltwise\n",
      *graph);
  insertMkldnnPrePackedOps(graph);
  GRAPH_DEBUG(
      "After insertMkldnnPrePackedOps, before FuseReluWithPackedOps\n", *graph);
  FuseReluWithPackedOps(graph);
  GRAPH_DEBUG(
      "After FuseReluWithPackedOps, before FoldPrePackingOps\n", *graph);
  FoldPrePackingOps(graph);
  GRAPH_DEBUG("After FoldPrePackingOps. End of FuseConvWithEltwise\n", *graph);
}

#else

void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG("MKLDNN Not enabled");
}

#endif // AT_MKLDNN_ENABLED()

} // namespace torch::jit