// pytorch/torch/csrc/jit/passes/mkldnn_rewrite.cpp

#include <ATen/Config.h>
#include <ATen/code_template.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/mkldnn_rewrite.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>

namespace torch::jit {

#if AT_MKLDNN_ENABLED()

// Reads the (profiled) sizes off the TensorType of input `idx` of `n`.
// Callers are expected to have verified that the input carries complete
// TensorType information.
static c10::VaryingShape<int64_t> getSizesOf(Node* n, size_t idx) {
  auto tt = n->input(idx)->type()->cast<TensorType>();
  return tt->sizes();
}
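
// Rewrites one aten::conv2d node into the pre-packed form:
//   %ctx = mkldnn_prepacked::conv2d_prepack(%weight, %bias, %stride,
//       %padding, %dilation, %groups, %input_size, "none")
//   %out = mkldnn_prepacked::conv2d_run(%input, %ctx)
// Bails out when the input or weight is not ChannelsLast contiguous, and
// for depthwise convolutions that NNC handles itself.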
static void insertPrePackedConvOpForNode(Node* n) {
  constexpr int POS_INPUT = 0;
  constexpr int POS_WEIGHT = 1;

  if (!tensorexpr::isContiguous(
          n->input(POS_INPUT), at::MemoryFormat::ChannelsLast)) {
    GRAPH_DEBUG(
        "insertPrePackedConvOpForNode: input is not ChannelsLast contiguous");
    return;
  }

  if (!tensorexpr::isContiguous(
          n->input(POS_WEIGHT), at::MemoryFormat::ChannelsLast)) {
    GRAPH_DEBUG(
        "insertPrePackedConvOpForNode: weight is not ChannelsLast contiguous");
    return;
  }

  // Leave depthwise conv2d to NNC.
  if (tensorexpr::conv2dIsSupportedJit(n)) {
    GRAPH_DEBUG("insertPrePackedConvOpForNode: leave depthwise conv2d to NNC");
    return;
  }

  WithInsertPoint guard(n);
  auto graph = n->owningGraph();
  auto input_sizes = getSizesOf(n, POS_INPUT);
  IValue input_size_value(*input_sizes.concrete_sizes());
  auto input_size = graph->insertConstant(input_size_value);

  auto prepack_node = graph->create(
      Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"), 1);
  // Skip the input tensor (index 0): the prepack op takes weight, bias,
  // stride, padding, dilation and groups, plus the input size and the
  // post-op attr tag ("none" for a plain convolution).
  for (const auto i : c10::irange(1, n->inputs().size())) {
    Value* v = n->input(i);
    prepack_node->addInput(v);
  }
  prepack_node->addInput(input_size);
  auto attr = graph->insertConstant(IValue("none"));
  prepack_node->addInput(attr);
  prepack_node->output()->setType(
      getCustomClass("__torch__.torch.classes.mkldnn.ConvOpContext"));
  graph->insertNode(prepack_node);

  auto prepack_conv = graph->insertNode(
      graph->create(Symbol::fromQualString("mkldnn_prepacked::conv2d_run"), 1));
  prepack_conv->addInput(n->input(0));
  prepack_conv->addInput(prepack_node->output());
  prepack_conv->output()->setType(n->output()->type()->cast<TensorType>());

  n->output()->replaceAllUsesWith(prepack_conv->output());
}
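
// True iff every input with TensorType has a known CPU device; inputs
// without TensorType are skipped, and a missing device is treated as
// non-CPU.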
static bool isTensorTypeCPU(Node* node) {
  for (const auto& input : node->inputs()) {
    auto type = input->type()->cast<TensorType>();
    if (!type) {
      continue;
    }
    auto device = type->device();
    if (!device) {
      return false;
    }
    if (!device->is_cpu()) {
      return false;
    }
  }
  return true;
}
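
// Recursively rewrites every CPU aten::conv2d under `b`, then removes the
// now-dead original conv nodes.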
static void insertPrePackedConvOp(Block* b) {
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      insertPrePackedConvOp(block);
    }

    if (n->kind() == aten::conv2d) {
      if (isTensorTypeCPU(n)) {
        insertPrePackedConvOpForNode(n);
      }
    }
  }

  EliminateDeadCode(b);
}
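
// Conv2d is currently the only op rewritten to an MKL-DNN prepacked form;
// insertMkldnnPrePackedOps is the wrapper where further ops would hook in.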
static void insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  insertPrePackedConvOp(graph->block());
}

static void insertMkldnnPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertMkldnnPrePackedConv2dOp(graph);
}
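
// For each eltwise post-op in mkldnn::fusion_rewrite_map (keyed by the
// aten op name), folds the op into the prepack call by swapping the "none"
// attr for the op's name, e.g.:
//   conv2d_prepack(..., "none") + conv2d_run + aten::relu
//   -> conv2d_prepack(..., "relu") + conv2d_run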
static void FuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  auto conv_op_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
        %packed_weight_bias = mkldnn_prepacked::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %input_size, %dummy_attr)
        %conv2d_res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias)
        %res = aten::${op}(%conv2d_res)
        return (%res))");

  auto conv_op_fused_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
        %attr: str = prim::Constant[value="${op_attr}"]()
        %packed_weight_bias : __torch__.torch.classes.mkldnn.ConvOpContext = mkldnn_prepacked::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %input_size, %attr)
        %res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias)
        return (%res))");

  for (auto const& it : mkldnn::fusion_rewrite_map) {
    std::string op = it.first;
    if (op == "none") {
      continue;
    }

    at::jit::TemplateEnv env;
    env.s("op", op);
    at::jit::TemplateEnv env_fused;
    env_fused.s("op_attr", op);

    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(
        conv_op_rstring.format(env), conv_op_fused_rstring.format(env_fused));
    auto filters = it.second;
    rewriter.runOnGraph(graph, filters);
  }
}
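
// Constant-folds conv2d_prepack nodes whose inputs are all constants: the
// prepack is executed once at compile time and its output (the packed
// ConvOpContext) is inserted back into the graph as a constant.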
static void PrePackingOpsFolder(Block* b) {
  auto is_foldable_op = [](const Node* n) -> bool {
    return (
        n->kind() ==
        Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"));
  };

  std::unordered_set<Node*> nodes_to_delete;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      PrePackingOpsFolder(block);
    }
    if (is_foldable_op(n)) {
      auto optional_outputs = torch::jit::runNodeIfInputsAreConstant(n);
      if (optional_outputs) {
        auto outputs = optional_outputs.value();
        TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output");
        Value* prepack_op_value = n->output(0);
        auto graph = n->owningGraph();
        WithInsertPoint ins(prepack_op_value->node());
        auto weak_class_obj =
            outputs[0].toObject()->copy_to_weak_compilation_ref();
        Value* packed_weight = graph->insertConstant(weak_class_obj)
                                   ->setType(n->output(0)->type());
        prepack_op_value->replaceAllUsesWith(packed_weight);
        nodes_to_delete.insert(n);
      }
    }
  }

  // Remove all inputs first so that no deleted node still uses another,
  // then destroy the nodes.
  for (auto n : nodes_to_delete) {
    n->removeAllInputs();
  }
  for (auto n : nodes_to_delete) {
    n->destroy();
  }
}

static void FoldPrePackingOps(std::shared_ptr<Graph>& graph) {
  PrePackingOpsFolder(graph->block());
}
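
// Entry point. Rewrites CPU conv2d nodes into MKL-DNN prepacked ops, fuses
// supported eltwise post-ops into them, and folds the prepack calls into
// graph constants.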
void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before insertMkldnnPrePackedOps. Beginning of FuseConvWithEltwise\n",
      *graph);
  insertMkldnnPrePackedOps(graph);
  GRAPH_DEBUG(
      "After insertMkldnnPrePackedOps, before FuseReluWithPackedOps\n", *graph);
  FuseReluWithPackedOps(graph);
  GRAPH_DEBUG(
      "After FuseReluWithPackedOps, before FoldPrePackingOps\n", *graph);
  FoldPrePackingOps(graph);
  GRAPH_DEBUG("After FoldPrePackingOps. End of FuseConvWithEltwise\n", *graph);
}

#else

void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG("MKLDNN not enabled");
}

#endif // AT_MKLDNN_ENABLED()

} // namespace torch::jit