add hardtanh(0,6) to the set of MKLDNN fusible ops for mobilenetv2 (#56203)

Summary:
TODO: post the numbers for mobilenetv2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56203

Reviewed By: malfet

Differential Revision: D27917557

Pulled By: Krovatkin

fbshipit-source-id: acea0f933a7e8c7a036a494295f68222c46a36f7
Nikolay Korovaiko, 2021-04-23 08:04:37 -07:00, committed by Facebook GitHub Bot
parent 7b7a4750a9
commit d6a25a58f5
5 changed files with 71 additions and 22 deletions
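The change generalizes the existing ReLU6 fusion: ReLU6 is exactly Hardtanh clamped to [0, 6], so a single attribute-carrying prim::MKLDNNHardTanh op can subsume the dedicated prim::MKLDNNRelu6. A minimal sketch of the user-visible effect (toy model, not MobileNetV2; the MKLDNN rewrite only fires on builds with MKL-DNN support, but freezing must preserve numerics either way):

import torch

# Conv followed by the clamp MobileNetV2 uses (ReLU6 == Hardtanh(0, 6)).
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3),
    torch.nn.Hardtanh(0., 6.),
).eval()

x = torch.randn(1, 3, 224, 224)
frozen = torch.jit.freeze(torch.jit.script(model))
assert torch.allclose(model(x), frozen(x))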

aten/src/ATen/core/interned_strings.h

@@ -33,7 +33,7 @@ namespace c10 {
   _(prim, MKLDNNGroup)              \
   _(prim, MKLDNNHardSwish)          \
   _(prim, MKLDNNHardSigmoid)        \
-  _(prim, MKLDNNRelu6)              \
+  _(prim, MKLDNNHardTanh)           \
   _(prim, Drop)                     \
   _(prim, Eval)                     \
   _(prim, Expand) /* onnx */        \
@@ -337,6 +337,7 @@ namespace c10 {
   _(aten, hardswish)                \
   _(aten, hardswish_)               \
   _(aten, hardsigmoid_)             \
+  _(aten, hardtanh_)                \
   FORALL_ATEN_BASE_SYMBOLS(_)       \
   _(onnx, Add)                      \
   _(onnx, Concat)                   \

test/backward_compatibility/check_backward_compatibility.py

@@ -30,6 +30,8 @@ allow_list = [
     # Internal
     ("static", datetime.date(9999, 1, 1)),
     ("prim::ModuleDictIndex", datetime.date(9999, 1, 1)),
+    ("prim::MKLDNNRelu6", datetime.date(9999, 1, 1)),
+    ("prim::MKLDNNRelu6_", datetime.date(9999, 1, 1)),
     # Internal, profiler-specific ops
     ("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)),
     ("profiler::_record_function_enter", datetime.date(9999, 1, 1)),

test/jit/test_freezing.py

@@ -1866,14 +1866,17 @@ class TestFrozenOptimizations(JitTestCase):
     def test_conv_hardswish(self):
         with set_default_dtype(torch.float):
             activations = [
-                torch.nn.Hardswish,
-                torch.nn.Hardsigmoid,
-                torch.nn.ReLU6,
+                torch.nn.Hardswish(),
+                torch.nn.Hardsigmoid(),
+                torch.nn.ReLU6(),
+                torch.nn.Hardtanh(0., 6.),
+                torch.nn.Hardtanh(1., 100.),
+                torch.nn.Hardtanh(-100., -1.),
             ]

             model = torchvision.models.resnet18()
             for activation in activations:
-                sub_model = torch.nn.Sequential(model.conv1, activation())
+                sub_model = torch.nn.Sequential(model.conv1, activation)
                 sub_model.eval()
                 mod = torch.jit.freeze(torch.jit.script(sub_model))
                 N, C, H, W, = 10, 3, 224, 224
@@ -1887,7 +1890,6 @@ class TestFrozenOptimizations(JitTestCase):
             op_map = {
                 'prim::MKLDNNHardSwish' : F.hardswish,
                 'prim::MKLDNNHardSigmoid' : F.hardsigmoid,
-                'prim::MKLDNNRelu6' : F.relu6
             }

             input_sizes = ([0], [1], [3], [1, 3, 8, 8])
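The activation list now holds module instances rather than classes because Hardtanh needs constructor arguments, which the old `activation()` call in the loop could not pass. A short sketch of the pattern (hypothetical conv shape and input size):

import torch

# Pre-built instances let parameterized activations sit next to no-arg ones.
activations = [torch.nn.ReLU6(), torch.nn.Hardtanh(1., 100.)]
for act in activations:
    sub_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), act).eval()
    frozen = torch.jit.freeze(torch.jit.script(sub_model))
    x = torch.randn(2, 3, 32, 32)
    assert torch.allclose(sub_model(x), frozen(x))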

torch/csrc/jit/ir/ir.cpp

@@ -1289,6 +1289,7 @@ Node* Node::replaceWithNewSymbol(Symbol new_symbol) {
     v->replaceAllUsesWith(new_out);
   }
   replace_node->copyMetadata(this);
+  replace_node->copyAttributes(*this);
   TORCH_INTERNAL_ASSERT(
       (replace_node->maybeOperator() != nullptr) == had_operator,
       "invalid symbol replacement:",

torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp

@@ -25,6 +25,7 @@
 #include <ATen/native/ConvUtils.h>
 #include <algorithm>
 #include <memory>
+#include <ATen/core/stack.h>

 #include <c10/core/Layout.h>
 #include <c10/util/StringUtil.h>
@@ -180,7 +181,7 @@ void InplaceMKLDNNSubgraph(std::shared_ptr<Graph> graph) {
     auto k = node->kind();
     if (k == aten::relu || k == aten::sigmoid || k == aten::dropout ||
         k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid ||
-        k == prim::MKLDNNRelu6) {
+        k == prim::MKLDNNHardTanh) {
       if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
         continue;
       }
@@ -349,6 +350,15 @@ Operation BroadOp(const Node* node) {
   };
 }

+static std::function<void(at::Tensor output, at::Tensor input)> hardtanh_helper(
+    const Node* n) {
+  auto min_val = n->f(attr::min_val);
+  auto max_val = n->f(attr::max_val);
+  return [min_val, max_val](at::Tensor output, at::Tensor input) {
+    at::cpu::hardtanh_out(output, input, min_val, max_val);
+  };
+}
+
 // any op added to this registry needs to meet
 // the precondition: `aten_op(0) == 0`
 const RegisterOperators MKLDNNHardSwishOpReg({
@@ -369,12 +379,10 @@ const RegisterOperators MKLDNNHardSwishOpReg({
             true),
         AliasAnalysisKind::FROM_SCHEMA),
     torch::jit::Operator(
-        "prim::MKLDNNRelu6_(Tensor(a!) self) -> Tensor(a!)",
-        createUnaryOp(
-            [](at::Tensor output, at::Tensor input) {
-              at::cpu::hardtanh_out(output, input, 0.f, 6.f);
-            },
-            true),
+        "prim::MKLDNNHardTanh_(Tensor(a!) self) -> Tensor(a!)",
+        [](const Node* n) -> Operation {
+          return createUnaryOp(hardtanh_helper(n), true);
+        },
         AliasAnalysisKind::FROM_SCHEMA),
     torch::jit::Operator(
         "prim::MKLDNNHardSwish(Tensor a) -> Tensor",
@@ -393,12 +401,10 @@ const RegisterOperators MKLDNNHardSwishOpReg({
             false),
         AliasAnalysisKind::FROM_SCHEMA),
     torch::jit::Operator(
-        "prim::MKLDNNRelu6(Tensor(a!) self) -> Tensor(a!)",
-        createUnaryOp(
-            [](at::Tensor output, at::Tensor input) {
-              at::cpu::hardtanh_out(output, input, 0.f, 6.f);
-            },
-            false),
+        "prim::MKLDNNHardTanh(Tensor self) -> Tensor",
+        [](const Node* n) -> Operation {
+          return createUnaryOp(hardtanh_helper(n), false);
+        },
         AliasAnalysisKind::FROM_SCHEMA),
 });
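Where the old registrations baked 0 and 6 into the lambdas, the new ones build the Operation per node: hardtanh_helper reads the min_val and max_val float attributes off the prim::MKLDNNHardTanh node and captures them in the closure. A Python analogue of that closure-factory pattern (illustrative only; the real code dispatches to at::cpu::hardtanh_out, and make_hardtanh_op is a hypothetical name):

import torch

def make_hardtanh_op(min_val, max_val):
    # Capture the node's attributes in a closure, as hardtanh_helper does,
    # so one registration covers every (min_val, max_val) pair.
    def op(out, inp):
        return torch.clamp(inp, min_val, max_val, out=out)
    return op

op = make_hardtanh_op(0., 6.)
x = torch.randn(4)
out = torch.empty_like(x)
op(out, x)
assert torch.allclose(out, torch.nn.functional.relu6(x))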
@@ -569,6 +575,25 @@ void moveWeightsToMKLDNN(Node* n) {
   }
 }

+static void hartanh_node_creator(
+    Node* body_node,
+    double min_val,
+    double max_val) {
+  WithInsertPoint insert_guard{body_node};
+  auto out_node = body_node->owningGraph()->create(
+      {prim::MKLDNNHardTanh}, {body_node->input(0)}, 1);
+  // N.B. we can't use `insert` as it calls `getOperation` (via
+  // `emitBuiltinCall`) which uses `min_val` and `max_val` attrs which we
+  // haven't set yet
+  body_node->owningGraph()->insertNode(out_node);
+  auto out_val = out_node->output();
+  out_node->f_(attr::min_val, min_val);
+  out_node->f_(attr::max_val, max_val);
+  out_val->copyMetadata(body_node->output());
+  body_node->output()->replaceAllUsesWith(out_val);
+  body_node->destroy();
+}
+
 void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
   auto graph = subgraph_node->owningGraph();
   Value* none_value = nullptr;
@@ -633,8 +658,16 @@ void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
     }

     if (body_node->kind() == aten::relu6) {
-      body_node->replaceWithNewSymbol(prim::MKLDNNRelu6);
-      body_node->destroy();
+      hartanh_node_creator(body_node, 0., 6.);
       continue;
     }

+    if (body_node->kind() == aten::hardtanh) {
+      auto min_val =
+          constant_as<double>(body_node->namedInput("min_val")).value();
+      auto max_val =
+          constant_as<double>(body_node->namedInput("max_val")).value();
+      hartanh_node_creator(body_node, min_val, max_val);
+      continue;
+    }
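aten::relu6 nodes are now lowered through the same helper with the constants (0., 6.) instead of getting their own symbol, relying on the pointwise identity below:

import torch
import torch.nn.functional as F

x = torch.randn(8)
# relu6 is hardtanh with min_val=0, max_val=6, elementwise and exactly.
assert torch.equal(F.relu6(x), F.hardtanh(x, 0., 6.))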
@@ -816,10 +849,19 @@ class MKLDNNSubgraphSlicer {
       // conversions. from initial testing including it speeds up models
       case aten::max_pool2d:
       case aten::max_pool3d:
       case aten::adaptive_avg_pool2d:
         return true;
     }

+    if (n->kind() == aten::hardtanh && !nonConstantParameters(n)) {
+      auto min_val = constant_as<double>(n->namedInput("min_val")).value();
+      auto max_val = constant_as<double>(n->namedInput("max_val")).value();
+      // we need to maintain the following invariant `pointwise_func(0) == 0`,
+      // see `createUnaryOp`
+      if (min_val <= 0. && max_val >= 0.) {
+        return true;
+      }
+    }
     if (n->kind() == aten::add || n->kind() == aten::mul) {
       // mkldnn doesn't currently support Tensor-Scalar add
       for (size_t i = 0; i < 2; i++) {
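The min_val <= 0 <= max_val guard exists because createUnaryOp may run the pointwise function over the raw MKL-DNN buffer, which in blocked layouts can include zero padding; the op is only safe to fuse if it maps 0 to 0. That is also why the test above adds Hardtanh(1., 100.) and Hardtanh(-100., -1.): those ranges must still compute correctly, just without taking the MKLDNN path. A quick check of the invariant:

import torch
import torch.nn.functional as F

zero = torch.zeros(1)
assert F.hardtanh(zero, 0., 6.).item() == 0.       # f(0) == 0: fusible
assert F.hardtanh(zero, -3., 3.).item() == 0.      # range straddles 0: fusible
assert F.hardtanh(zero, 1., 100.).item() == 1.     # f(0) != 0: excluded
assert F.hardtanh(zero, -100., -1.).item() == -1.  # f(0) != 0: excluded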
@@ -949,6 +991,7 @@ void ConvertFrozenOpsToMKLDNN(std::shared_ptr<Graph>& graph) {
         aten::dropout_,
         aten::sigmoid_,
         aten::hardsigmoid_,
+        aten::hardtanh_,
       };
       return mkldnn_ops.count(node_to_functionalize->kind()) != 0;
     });