add hardtanh(0,6) to the set of MKLDNN fusible ops for mobilenetv2 (#56203)
Summary:
TODO: post the numbers for mobilenetv2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56203
Reviewed By: malfet
Differential Revision: D27917557
Pulled By: Krovatkin
fbshipit-source-id: acea0f933a7e8c7a036a494295f68222c46a36f7

parent: 7b7a4750a9
commit: d6a25a58f5
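For context on the title: ReLU6, the activation used throughout mobilenetv2, is exactly Hardtanh with bounds (0, 6). Lowering aten::hardtanh to a single attribute-carrying prim::MKLDNNHardTanh op therefore subsumes the old prim::MKLDNNRelu6 while also covering other zero-preserving clamp ranges. A minimal sketch of the identity, using only stock PyTorch:

import torch
import torch.nn.functional as F

x = torch.randn(16)
# relu6(x) == hardtanh(x, 0, 6) == clamp(x, 0, 6), elementwise
assert torch.equal(F.relu6(x), F.hardtanh(x, 0., 6.))
assert torch.equal(F.hardtanh(x, 0., 6.), x.clamp(0., 6.))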
aten/src/ATen/core/interned_strings.h

@@ -33,7 +33,7 @@ namespace c10 {
 _(prim, MKLDNNGroup) \
 _(prim, MKLDNNHardSwish) \
 _(prim, MKLDNNHardSigmoid) \
-_(prim, MKLDNNRelu6) \
+_(prim, MKLDNNHardTanh) \
 _(prim, Drop) \
 _(prim, Eval) \
 _(prim, Expand) /* onnx */ \
@@ -337,6 +337,7 @@ namespace c10 {
 _(aten, hardswish) \
 _(aten, hardswish_) \
 _(aten, hardsigmoid_) \
+_(aten, hardtanh_) \
 FORALL_ATEN_BASE_SYMBOLS(_) \
 _(onnx, Add) \
 _(onnx, Concat) \
test/backward_compatibility/check_backward_compatibility.py

@@ -30,6 +30,8 @@ allow_list = [
     # Internal
     ("static", datetime.date(9999, 1, 1)),
     ("prim::ModuleDictIndex", datetime.date(9999, 1, 1)),
+    ("prim::MKLDNNRelu6", datetime.date(9999, 1, 1)),
+    ("prim::MKLDNNRelu6_", datetime.date(9999, 1, 1)),
     # Internal, profiler-specific ops
     ("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)),
     ("profiler::_record_function_enter", datetime.date(9999, 1, 1)),
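The two new entries whitelist the deleted prim::MKLDNNRelu6 schemas so the backward-compatibility check does not flag their removal. As a rough illustration of how a date-gated allow list like this can be consumed (a hypothetical sketch, not the actual checker; `is_allowed` and its matching logic are invented for illustration):

import datetime
import fnmatch

def is_allowed(schema_name, allow_list, today=None):
    # An entry suppresses the BC failure until its expiry date passes;
    # date(9999, 1, 1) effectively means "allowed forever".
    today = today or datetime.date.today()
    return any(
        fnmatch.fnmatchcase(schema_name, pattern + "*") and today < expiry
        for pattern, expiry in allow_list
    )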
test/jit/test_freezing.py

@@ -1866,14 +1866,17 @@ class TestFrozenOptimizations(JitTestCase):
     def test_conv_hardswish(self):
         with set_default_dtype(torch.float):
             activations = [
-                torch.nn.Hardswish,
-                torch.nn.Hardsigmoid,
-                torch.nn.ReLU6,
+                torch.nn.Hardswish(),
+                torch.nn.Hardsigmoid(),
+                torch.nn.ReLU6(),
+                torch.nn.Hardtanh(0., 6.),
+                torch.nn.Hardtanh(1., 100.),
+                torch.nn.Hardtanh(-100., -1.),
             ]

             model = torchvision.models.resnet18()
             for activation in activations:
-                sub_model = torch.nn.Sequential(model.conv1, activation())
+                sub_model = torch.nn.Sequential(model.conv1, activation)
                 sub_model.eval()
                 mod = torch.jit.freeze(torch.jit.script(sub_model))
                 N, C, H, W, = 10, 3, 224, 224
@@ -1887,7 +1890,6 @@ class TestFrozenOptimizations(JitTestCase):
             op_map = {
                 'prim::MKLDNNHardSwish' : F.hardswish,
                 'prim::MKLDNNHardSigmoid' : F.hardsigmoid,
-                'prim::MKLDNNRelu6' : F.relu6
             }

             input_sizes = ([0], [1], [3], [1, 3, 8, 8])
torch/csrc/jit/ir/ir.cpp

@@ -1289,6 +1289,7 @@ Node* Node::replaceWithNewSymbol(Symbol new_symbol) {
     v->replaceAllUsesWith(new_out);
   }
   replace_node->copyMetadata(this);
+  replace_node->copyAttributes(*this);
   TORCH_INTERNAL_ASSERT(
       (replace_node->maybeOperator() != nullptr) == had_operator,
       "invalid symbol replacement:",
torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp

@@ -25,6 +25,7 @@
 #include <ATen/native/ConvUtils.h>
 #include <algorithm>
 #include <memory>
+#include <ATen/core/stack.h>
 #include <c10/core/Layout.h>
 #include <c10/util/StringUtil.h>
@@ -180,7 +181,7 @@ void InplaceMKLDNNSubgraph(std::shared_ptr<Graph> graph) {
   auto k = node->kind();
   if (k == aten::relu || k == aten::sigmoid || k == aten::dropout ||
       k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid ||
-      k == prim::MKLDNNRelu6) {
+      k == prim::MKLDNNHardTanh) {
     if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
       continue;
     }
@@ -349,6 +350,15 @@ Operation BroadOp(const Node* node) {
   };
 }

+static std::function<void(at::Tensor output, at::Tensor input)> hardtanh_helper(
+    const Node* n) {
+  auto min_val = n->f(attr::min_val);
+  auto max_val = n->f(attr::max_val);
+  return [min_val, max_val](at::Tensor output, at::Tensor input) {
+    at::cpu::hardtanh_out(output, input, min_val, max_val);
+  };
+}
+
 // any op added to this registry needs to meet
 // the precondition: `aten_op(0) == 0`
 const RegisterOperators MKLDNNHardSwishOpReg({
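The helper reads min_val and max_val off the node once, at Operation-construction time, and captures them in the returned closure, so a single registered operator serves every clamp range. A rough Python analogue of that pattern (`node_attrs` stands in for the n->f(attr::...) attribute lookups; an illustrative sketch, not PyTorch API):

import torch

def hardtanh_helper(node_attrs):
    min_val, max_val = node_attrs["min_val"], node_attrs["max_val"]
    def op(output, input):
        # mirrors at::cpu::hardtanh_out: clamp input into output
        torch.clamp(input, min_val, max_val, out=output)
    return op

run = hardtanh_helper({"min_val": 0.0, "max_val": 6.0})
out = torch.empty(4)
run(out, torch.tensor([-1.0, 3.0, 7.0, 0.0]))  # out == [0., 3., 6., 0.]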
@@ -369,12 +379,10 @@ const RegisterOperators MKLDNNHardSwishOpReg({
             true),
         AliasAnalysisKind::FROM_SCHEMA),
     torch::jit::Operator(
-        "prim::MKLDNNRelu6_(Tensor(a!) self) -> Tensor(a!)",
-        createUnaryOp(
-            [](at::Tensor output, at::Tensor input) {
-              at::cpu::hardtanh_out(output, input, 0.f, 6.f);
-            },
-            true),
+        "prim::MKLDNNHardTanh_(Tensor(a!) self) -> Tensor(a!)",
+        [](const Node* n) -> Operation {
+          return createUnaryOp(hardtanh_helper(n), true);
+        },
         AliasAnalysisKind::FROM_SCHEMA),
     torch::jit::Operator(
         "prim::MKLDNNHardSwish(Tensor a) -> Tensor",
@@ -393,12 +401,10 @@ const RegisterOperators MKLDNNHardSwishOpReg({
             false),
         AliasAnalysisKind::FROM_SCHEMA),
     torch::jit::Operator(
-        "prim::MKLDNNRelu6(Tensor(a!) self) -> Tensor(a!)",
-        createUnaryOp(
-            [](at::Tensor output, at::Tensor input) {
-              at::cpu::hardtanh_out(output, input, 0.f, 6.f);
-            },
-            false),
+        "prim::MKLDNNHardTanh(Tensor self) -> Tensor",
+        [](const Node* n) -> Operation {
+          return createUnaryOp(hardtanh_helper(n), false);
+        },
         AliasAnalysisKind::FROM_SCHEMA),
 });

@@ -569,6 +575,25 @@ void moveWeightsToMKLDNN(Node* n) {
   }
 }

+static void hartanh_node_creator(
+    Node* body_node,
+    double min_val,
+    double max_val) {
+  WithInsertPoint insert_guard{body_node};
+  auto out_node = body_node->owningGraph()->create(
+      {prim::MKLDNNHardTanh}, {body_node->input(0)}, 1);
+  // N.B. we can't use `insert` as it calls `getOperation` (via
+  // `emitBuiltinCall`) which uses `min_val` and `max_val` attrs which we
+  // haven't set yet.
+  body_node->owningGraph()->insertNode(out_node);
+  auto out_val = out_node->output();
+  out_node->f_(attr::min_val, min_val);
+  out_node->f_(attr::max_val, max_val);
+  out_val->copyMetadata(body_node->output());
+  body_node->output()->replaceAllUsesWith(out_val);
+  body_node->destroy();
+}
+
 void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
   auto graph = subgraph_node->owningGraph();
   Value* none_value = nullptr;
@@ -633,8 +658,16 @@ void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
     }

     if (body_node->kind() == aten::relu6) {
-      body_node->replaceWithNewSymbol(prim::MKLDNNRelu6);
-      body_node->destroy();
+      hartanh_node_creator(body_node, 0., 6.);
       continue;
     }

+    if (body_node->kind() == aten::hardtanh) {
+      auto min_val =
+          constant_as<double>(body_node->namedInput("min_val")).value();
+      auto max_val =
+          constant_as<double>(body_node->namedInput("max_val")).value();
+      hartanh_node_creator(body_node, min_val, max_val);
+      continue;
+    }
+
@@ -816,10 +849,19 @@ class MKLDNNSubgraphSlicer {
       // conversions. from initial testing including it speeds up models
       case aten::max_pool2d:
       case aten::max_pool3d:
       case aten::adaptive_avg_pool2d:
         return true;
     }

+    if (n->kind() == aten::hardtanh && !nonConstantParameters(n)) {
+      auto min_val = constant_as<double>(n->namedInput("min_val")).value();
+      auto max_val = constant_as<double>(n->namedInput("max_val")).value();
+      // we need to maintain the following invariant `pointwise_func(0) == 0`,
+      // see `createUnaryOp`
+      if (min_val <= 0. && max_val >= 0.) {
+        return true;
+      }
+    }
+
     if (n->kind() == aten::add || n->kind() == aten::mul) {
       // mkldnn doesn't currently support Tensor-Scalar add
       for (size_t i = 0; i < 2; i++) {
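The gate above enforces the precondition named in the comment (see `createUnaryOp` for the rationale): only pointwise functions that map 0 to 0 may be fused, so a clamp range qualifies exactly when it contains zero. A quick check of which of the test's Hardtanh ranges satisfy it:

import torch
import torch.nn.functional as F

z = torch.zeros(1)
assert F.hardtanh(z, 0., 6.).item() == 0.       # fusible (this is relu6)
assert F.hardtanh(z, 1., 100.).item() == 1.     # 0 -> 1: not fusible
assert F.hardtanh(z, -100., -1.).item() == -1.  # 0 -> -1: not fusible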
@@ -949,6 +991,7 @@ void ConvertFrozenOpsToMKLDNN(std::shared_ptr<Graph>& graph) {
             aten::dropout_,
             aten::sigmoid_,
             aten::hardsigmoid_,
+            aten::hardtanh_,
         };
         return mkldnn_ops.count(node_to_functionalize->kind()) != 0;
       });