mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
This is reland of #87603 with definitions of c10::stoXX kept for further investigation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109566 Approved by: https://github.com/huydhn
133 lines
3.7 KiB
C++
133 lines
3.7 KiB
C++
#include "caffe2/opt/mobile.h"
|
|
#include "caffe2/core/logging.h"
|
|
#include "caffe2/opt/converter.h"
|
|
#include "caffe2/opt/fusion.h"
|
|
#include "caffe2/opt/passes.h"
|
|
|
|
namespace caffe2 {
|
|
namespace opt {
|
|
|
|
using namespace nom;
|
|
|
|
void addNNPACK(repr::NNModule* nn, bool low_memory) {
  // Walk the dataflow graph and retag every eligible convolution with the
  // NNPACK engine. Eligibility: NCHW layout, unit strides, a square-ish 2D
  // kernel, and X/W/b all present. Unless low_memory is requested, also ask
  // NNPACK to precompute the weight transform.
  for (auto node : nn->dataFlow.getMutableNodes()) {
    // Tensor/blob nodes are not operators; skip them.
    NOM_REQUIRE_OR_CONT(repr::nn::is<repr::NeuralNetOperator>(node));

    // Only convolutions are candidates.
    auto nnOp = repr::nn::get<repr::NeuralNetOperator>(node);
    NOM_REQUIRE_OR_CONT(isa<nom::repr::Conv>(nnOp));

    // NNPACK needs all of X, W and b wired in.
    NOM_REQUIRE_OR_CONT(node->getInEdges().size() >= 3);

    auto conv = dyn_cast<nom::repr::Conv>(nnOp);

    // NNPACK kernels assume NCHW layout.
    NOM_REQUIRE_OR_CONT(conv->getLayout() == nom::repr::Conv::NNLayout::NCHW);

    // NNPACK only supports stride == 1 in every spatial dimension.
    bool unitStrides = true;
    for (auto stride : conv->getStrides()) {
      unitStrides = unitStrides && (stride == 1);
    }
    NOM_REQUIRE_OR_CONT(unitStrides);

    // NNPACK only supports 2D convolution.
    const auto& kernelShape = conv->getKernelShape();
    NOM_REQUIRE_OR_CONT(kernelShape.size() == 2);

    // Kx1 and 1xK kernels are inefficient in NNPACK; reject them unless the
    // kernel is square (KxK, including 1x1).
    if (kernelShape[0] != kernelShape[1]) {
      NOM_REQUIRE_OR_CONT(kernelShape[0] != 1 && kernelShape[1] != 1);
    }

    // The conv qualifies — rewrite its operator def to use NNPACK.
    auto annotation = conv->getMutableAnnotation();
    NOM_REQUIRE_OR_CONT(annotation && isa<Caffe2Annotation>(annotation));

    auto* op = dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
    op->set_engine("NNPACK");
    if (!low_memory) {
      // Trade memory for speed: precompute the weight transform once.
      auto* precompute_argument = op->add_arg();
      precompute_argument->set_name("convolution_transform_strategy");
      precompute_argument->set_s("PRECOMPUTE");
    }
  }
}
|
|
|
|
namespace {
|
|
|
|
inline bool isNNPACKConvReluEfficient(
|
|
const std::string& algo,
|
|
const repr::Conv& conv) {
|
|
if (algo == "AUTO" || algo == "") {
|
|
for (auto stride : conv.getStrides()) {
|
|
if (stride > 1) {
|
|
return false;
|
|
}
|
|
}
|
|
for (auto kernel : conv.getKernelShape()) {
|
|
if (kernel < 2) {
|
|
return false;
|
|
}
|
|
}
|
|
} else if (!(algo == "WINOGRAD" || algo == "WINOGRAD_FP16" ||
|
|
algo == "FT8x8" || algo == "FT16x16")) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
void fuseNNPACKConvRelu(repr::NNModule* nn) {
  // Predicate: fuse only convolutions already assigned to the NNPACK engine
  // whose algorithm selection makes a fused conv+relu worthwhile.
  auto should_fuse = [](const repr::Conv& conv) {
    const auto annotation = conv.getAnnotation();
    if (!annotation || !isa<Caffe2Annotation>(annotation)) {
      return false;
    }
    const auto& op = dyn_cast<Caffe2Annotation>(annotation)->getOperatorDef();

    // We only want to fuse for fast NNPACK convs.
    if (op.engine() != "NNPACK") {
      return false;
    }

    // Scan the args for an explicit algorithm; the last "algo" wins and the
    // default is automatic selection.
    std::string algo = "AUTO";
    for (const auto& arg : op.arg()) {
      if (arg.name() == "algo") {
        algo = arg.s();
      }
    }
    return isNNPACKConvReluEfficient(algo, conv);
  };

  // After fusion, record the activation on the conv's operator def so the
  // NNPACK runtime applies Relu in the same pass.
  auto postprocess = [](repr::NNGraph::NodeRef conv_node) {
    auto conv = repr::nn::get<repr::Conv>(conv_node);
    auto annotation = conv->getMutableAnnotation();
    if (!annotation || !isa<Caffe2Annotation>(annotation)) {
      return;
    }
    auto* op = dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
    auto* arg = op->add_arg();
    arg->set_name("activation");
    arg->set_s("Relu");
  };

  fuseActivation<repr::Conv, repr::Relu>(nn, should_fuse, postprocess);
}
|
|
|
|
// Expose both transforms to the optimization-pass registry so they can be
// invoked by name (see caffe2/opt/passes.h).
REGISTER_OPT_PASS_FROM_FUNC(FuseNNPACKConvRelu, fuseNNPACKConvRelu);
REGISTER_OPT_PASS_FROM_FUNC(AddNNPACK, addNNPACK);
|
|
|
|
} // namespace opt
|
|
} // namespace caffe2
|